# JoshBot / game_api.py
# (Hugging Face Hub listing: uploaded by LittleMonkeyLab, "Upload 19 files", commit e44e7cc verified)
"""
FastAPI Backend for Trading Game + AI Chatbot Integration
Supports tunable AI parameters for psychological experiments
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import Optional, List, Dict
import json
import os
from datetime import datetime
# Import your existing chain from app.py
from app import chain, llm, ChatOpenAI, PromptTemplate
from langchain_core.prompts import ChatPromptTemplate
# FastAPI application instance; the title shows up in the auto-generated OpenAPI docs.
app = FastAPI(title="Trading Game AI Experiment API")

# CORS for local development
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static files from game directory
app.mount("/game", StaticFiles(directory="game", html=True), name="game")

# Experiment state storage (in production, use a database)
# NOTE(review): experiment_state is never read or written in this file — confirm it is used elsewhere or dead.
experiment_state = {}
# Maps session_id -> {"chat_history", "decisions", "trust_scores", "params_history"}.
game_sessions = {}
class AIMessageRequest(BaseModel):
    """Request payload for /api/ai/chat: a question plus tunable experiment parameters."""
    question: str
    # Prior (human, ai) message pairs from the client; items are validated/coerced downstream.
    chat_history: List[tuple] = []
    risk_level: float = 5.0  # 0-10, affects advice certainty
    temperature: float = 0.7  # 0.0-2.0, affects response randomness
    confidence_boost: float = 0.0  # -100 to +100, manipulates trustworthiness
    session_id: str = "default"
class TradingDecision(BaseModel):
    """A single player trade, logged via /api/experiment/decision for trust analysis."""
    session_id: str
    symbol: str
    action: str  # "buy" or "sell"
    quantity: int
    price: float
    ai_advice_followed: bool  # whether the player acted on the AI's recommendation
    trust_score: Optional[float] = None  # 1-10, explicit trust rating
class ScenarioTrigger(BaseModel):
    """Request to trigger a scripted experiment scenario for a session."""
    session_id: str
    scenario_type: str  # "volatility", "large_position", "loss_recovery", "news_event"
    context: Dict  # free-form scenario context, echoed back to the frontend
    ai_required: bool = True  # whether the frontend should open the AI chat for this scenario
def get_ai_with_params(risk_level: float, temperature: float, confidence_boost: float):
    """Create an LLM client tuned for the current experiment condition.

    Args:
        risk_level: 0-10 risk profile; contributes up to +0.3 extra temperature
            (higher risk = more response variability).
        temperature: base sampling temperature, nominally 0.0-2.0.
        confidence_boost: accepted for interface symmetry with the other tunables;
            it only affects prompt wording (see get_contextual_prompt), not the LLM.

    Returns:
        A ChatOpenAI instance pointed at the Hugging Face inference router.

    Raises:
        ValueError: if the HF_TOKEN environment variable is not set.
    """
    # Adjust temperature based on risk level, then clamp to the API's valid
    # [0.0, 2.0] range (the original only capped the upper bound).
    adjusted_temp = temperature + (risk_level / 10.0) * 0.3
    adjusted_temp = max(0.0, min(adjusted_temp, 2.0))

    # `os` is imported at module level; the redundant local import was removed.
    model_name = os.getenv("HF_MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct:novita")
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN environment variable is not set.")

    tuned_llm = ChatOpenAI(
        model=model_name,
        base_url="https://router.huggingface.co/v1",
        api_key=api_key,
        temperature=adjusted_temp,
        max_tokens=512,
    )
    return tuned_llm
def get_contextual_prompt(context: Dict, risk_level: float, confidence_boost: float):
    """Generate a PromptTemplate whose wording reflects risk level and confidence boost.

    Args:
        context: optional keys "market", "portfolio", "events" (sensible defaults used
            for any that are missing).
        risk_level: 0-10; mapped to a textual risk descriptor.
        confidence_boost: -100 to +100; >20 asserts high confidence, >0 confidence,
            <-20 explicit uncertainty. Values in [-20, 0] intentionally add no text.

    Returns:
        A PromptTemplate with input variables "context" and "question".
    """
    # Every integer level 0-10 is mapped explicitly. The original dict only had
    # keys {0,2,4,5,6,8,10}, so levels 1,3,7,9 silently fell back to "moderate",
    # which was inconsistent with the mapping used in the chat endpoint.
    risk_descriptor = {
        0: "extremely conservative",
        1: "very conservative",
        2: "very conservative",
        3: "conservative",
        4: "conservative",
        5: "moderate",
        6: "moderately aggressive",
        7: "moderately aggressive",
        8: "aggressive",
        9: "aggressive",
        10: "very aggressive",
    }
    risk_text = risk_descriptor.get(int(risk_level), "moderate")

    confidence_text = ""
    if confidence_boost > 20:
        confidence_text = "You are highly confident in your recommendations."
    elif confidence_boost > 0:
        confidence_text = "You are confident in your recommendations."
    elif confidence_boost < -20:
        confidence_text = "You are uncertain and should express caution in your recommendations."

    # {{context}} / {{question}} survive the .format() call below as {context} /
    # {question}, becoming the PromptTemplate's runtime input variables.
    base_template = """
You are an AI trading advisor for the Quantum Financial Network. Your risk profile is {risk_level}.
{confidence_text}
You provide trading advice based on market data. Consider:
- Current market conditions: {market_context}
- Player portfolio status: {portfolio_context}
- Recent events: {events_context}
Answer the question with appropriate caution/certainty based on your risk profile.
Context:
{{context}}
Question: {{question}}
Answer:
"""
    return PromptTemplate(
        input_variables=["context", "question"],
        template=base_template.format(
            risk_level=risk_text,
            confidence_text=confidence_text,
            market_context=context.get("market", "normal conditions"),
            portfolio_context=context.get("portfolio", "standard portfolio"),
            events_context=context.get("events", "no major events")
        )
    )
@app.post("/api/ai/chat")
async def chat_with_ai(request: AIMessageRequest):
    """Main AI chat endpoint with tunable parameters.

    Prefixes the question with a risk-profile line, calls the shared `chain`
    from app.py, logs the interaction into the session's params_history, and
    returns the answer plus truncated source snippets.

    NOTE(review): request.temperature and request.confidence_boost are logged
    but never applied — the pre-built `chain` is used as-is (POC shortcut);
    only risk_level influences the prompt. Confirm this is intentional.
    NOTE(review): session["chat_history"] is initialized but never appended to;
    history persistence appears to be the client's responsibility — verify.
    """
    try:
        # Get or initialize session
        if request.session_id not in game_sessions:
            game_sessions[request.session_id] = {
                "chat_history": [],
                "decisions": [],
                "trust_scores": [],
                "params_history": []
            }
        session = game_sessions[request.session_id]
        # Use the existing chain from app.py for now (simplified for POC)
        # Convert chat_history from list of tuples to proper format if needed.
        # Malformed items (non-sequences, fewer than 2 elements) are dropped.
        chat_history_tuples = []
        if request.chat_history:
            for item in request.chat_history:
                if isinstance(item, (list, tuple)) and len(item) >= 2:
                    chat_history_tuples.append((str(item[0]), str(item[1])))
        # Enhance the question with scenario context and slider parameters
        # Extract scenario context if present in question
        enhanced_question = request.question
        # Add risk tolerance context to question if not already included
        if "risk tolerance" not in enhanced_question.lower() and request.risk_level is not None:
            risk_desc = {
                0: "extremely conservative",
                1: "very conservative",
                3: "conservative",
                5: "moderate",
                7: "moderately aggressive",
                9: "aggressive",
                10: "very aggressive"
            }
            # Levels missing from the dict (2, 4, 6, 8) fall back to "moderate".
            risk_text = risk_desc.get(int(request.risk_level), "moderate")
            if "Current Market Scenario" not in enhanced_question:
                enhanced_question = f"Risk Profile: {risk_text} ({request.risk_level}/10)\n\n{enhanced_question}"
        # Get AI response using the existing chain
        result = chain({
            "question": enhanced_question,
            "chat_history": chat_history_tuples
        })
        # Log interaction for experiment (original question, not the enhanced one)
        interaction = {
            "timestamp": datetime.now().isoformat(),
            "question": request.question,
            "response": result["answer"],
            "risk_level": request.risk_level,
            "temperature": request.temperature,
            "confidence_boost": request.confidence_boost
        }
        session["params_history"].append(interaction)
        return {
            "answer": result["answer"],
            # First 100 chars of each retrieved source document, when the chain provides them
            "sources": [doc.page_content[:100] for doc in result.get("source_documents", [])] if "source_documents" in result else [],
            # Index of this interaction in params_history, usable as a correlation id
            "interaction_id": len(session["params_history"]) - 1
        }
    except Exception as e:
        # Broad catch converts any failure into a 500 with the full traceback
        # in the detail — convenient for the experiment POC, but leaks internals;
        # tighten before production.
        import traceback
        error_detail = str(e) + "\n" + traceback.format_exc()
        raise HTTPException(status_code=500, detail=error_detail)
@app.post("/api/experiment/decision")
async def log_decision(decision: TradingDecision):
    """Log a player trading decision for trust analysis.

    Creates the session record on first use, appends the decision with a
    timestamp, and returns its index as decision_id.
    """
    # Lazily create the session bucket if this is the first event for it.
    session = game_sessions.setdefault(decision.session_id, {
        "chat_history": [],
        "decisions": [],
        "trust_scores": [],
        "params_history": []
    })

    record = {
        "timestamp": datetime.now().isoformat(),
        "symbol": decision.symbol,
        "action": decision.action,
        "quantity": decision.quantity,
        "price": decision.price,
        "ai_advice_followed": decision.ai_advice_followed,
        "trust_score": decision.trust_score
    }
    session["decisions"].append(record)

    return {"status": "logged", "decision_id": len(session["decisions"]) - 1}
@app.post("/api/experiment/scenario")
async def trigger_scenario(scenario: ScenarioTrigger):
    """Trigger a situational scenario that requires AI assistance.

    Looks up the canned prompt for the scenario type and hands it back to the
    frontend, which is responsible for opening the AI chat with it.
    """
    scenario_prompts = {
        "volatility": "The market is experiencing high volatility. A stock in your portfolio has moved 10% in the last hour. What should you do?",
        "large_position": "You're about to make a large position trade ($10,000+). This would represent a significant portion of your portfolio. Should you proceed?",
        "loss_recovery": "You're down 5% today. Would you like advice on whether to cut losses or hold your positions?",
        "news_event": "Breaking news just released that affects several stocks in your watchlist. How should this impact your trading decisions?"
    }

    # Guard clause: reject unrecognized scenario types up front.
    prompt = scenario_prompts.get(scenario.scenario_type)
    if prompt is None:
        raise HTTPException(status_code=400, detail="Unknown scenario type")

    # Scenario data for the frontend to feed into the AI chat.
    return {
        "scenario_type": scenario.scenario_type,
        "prompt": prompt,
        "context": scenario.context,
        "requires_ai": scenario.ai_required
    }
@app.get("/api/experiment/session/{session_id}")
async def get_session_data(session_id: str):
    """Return all experiment data recorded for a session, or 404 if unknown."""
    # EAFP: try the lookup and translate a miss into an HTTP 404.
    try:
        return game_sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Session not found")
@app.get("/api/experiment/export/{session_id}")
async def export_experiment_data(session_id: str):
    """Export a session's experiment data as JSON for offline analysis."""
    # Unknown sessions get a 404; known ones are returned verbatim.
    if session_id in game_sessions:
        return game_sessions[session_id]
    raise HTTPException(status_code=404, detail="Session not found")
@app.get("/")
async def root():
    """Redirect the site root to the game front-end."""
    # Local import keeps this optional dependency off the module's hot path.
    from fastapi.responses import RedirectResponse
    return RedirectResponse(url="/game/trade.html")
@app.get("/game")
async def game_redirect():
    """Redirect /game to the game's entry page.

    NOTE(review): the StaticFiles mount at "/game" (registered earlier) likely
    matches this path first, so this handler may be unreachable — verify
    routing order and whether this endpoint is needed.
    """
    from fastapi.responses import RedirectResponse
    return RedirectResponse(url="/game/trade.html")
if __name__ == "__main__":
    # Local development entry point: serve on all interfaces at port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)