| | """ |
| | FastAPI Backend for Trading Game + AI Chatbot Integration |
| | Supports tunable AI parameters for psychological experiments |
| | """ |
| | from fastapi import FastAPI, HTTPException |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi.staticfiles import StaticFiles |
| | from pydantic import BaseModel |
| | from typing import Optional, List, Dict |
| | import json |
| | import os |
| | from datetime import datetime |
| |
|
| | |
| | from app import chain, llm, ChatOpenAI, PromptTemplate |
| | from langchain_core.prompts import ChatPromptTemplate |
| |
|
app = FastAPI(title="Trading Game AI Experiment API")

# Cross-origin policy: every origin, method and header is accepted, and
# credentialed requests are allowed.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive; consider listing the real front-end origin(s) before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Serve the static front-end from the local "game" directory under /game.
app.mount("/game", StaticFiles(directory="game", html=True), name="game")

# In-memory stores: plain module-level dicts, so their contents are lost when
# the process restarts and are not shared between worker processes.
# NOTE(review): experiment_state is not referenced anywhere in the visible code.
experiment_state = dict()
# session_id -> {"chat_history", "decisions", "trust_scores", "params_history"}
game_sessions = dict()
| |
|
class AIMessageRequest(BaseModel):
    """Request body for ``POST /api/ai/chat``."""

    # The player's question, forwarded to the AI chain.
    question: str
    # Prior conversation turns; the chat endpoint keeps only items with at
    # least two elements and uses the first two of each.
    chat_history: List[tuple] = list()
    # Advisor risk setting; the chat endpoint renders it as "<value>/10".
    risk_level: float = 5.0
    # Logged with each interaction; not applied to the model by the chat endpoint.
    temperature: float = 0.7
    # Logged with each interaction; not applied to the model by the chat endpoint.
    confidence_boost: float = 0.0
    # Key into the in-memory session store.
    session_id: str = "default"
| |
|
class TradingDecision(BaseModel):
    """Request body for ``POST /api/experiment/decision``: one player trade."""

    session_id: str  # session the decision is logged under
    symbol: str  # traded instrument symbol
    action: str  # free-form action label; not validated here
    quantity: int
    price: float
    ai_advice_followed: bool  # whether the player followed the AI's advice
    trust_score: Optional[float] = None  # optional trust rating; None when omitted
| |
|
class ScenarioTrigger(BaseModel):
    """Request body for ``POST /api/experiment/scenario``."""

    # NOTE(review): the scenario endpoint does not read session_id.
    session_id: str
    scenario_type: str  # must be a scenario key known to the endpoint, else HTTP 400
    context: Dict  # echoed back unchanged in the response
    ai_required: bool = True  # echoed back as "requires_ai"
| |
|
def get_ai_with_params(risk_level: float, temperature: float, confidence_boost: float):
    """Create an LLM client whose sampling temperature reflects the experiment knobs.

    The base ``temperature`` is raised by ``0.3 * risk_level / 10`` (i.e. up to
    +0.3 for a risk level of 10). The result is then clamped to [0.0, 2.0], the
    range accepted by OpenAI-compatible chat APIs.

    Args:
        risk_level: Advisor risk setting, presumably on a 0-10 scale (suggested
            by the ``/ 10`` scaling; confirm with callers).
        temperature: Base sampling temperature.
        confidence_boost: Accepted for interface compatibility but currently
            unused by this function.

    Returns:
        A ``ChatOpenAI`` client pointed at the Hugging Face router, using the
        model from ``HF_MODEL_NAME`` (or a Llama-3.1-8B default).

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is not set.
    """
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN environment variable is not set.")
    model_name = os.getenv("HF_MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct:novita")

    adjusted_temp = temperature + (risk_level / 10.0) * 0.3
    # Clamp both ends: a negative base temperature or risk level must not
    # produce a negative (invalid) temperature, and the upper cap stays at 2.0.
    adjusted_temp = min(max(adjusted_temp, 0.0), 2.0)

    return ChatOpenAI(
        model=model_name,
        base_url="https://router.huggingface.co/v1",
        api_key=api_key,
        temperature=adjusted_temp,
        max_tokens=512,
    )
| |
|
def get_contextual_prompt(context: Dict, risk_level: float, confidence_boost: float):
    """Build a PromptTemplate whose advisor persona reflects risk and confidence.

    The market, portfolio and event descriptions from ``context`` are baked into
    the template text. That leaves ``{context}`` and ``{question}`` as the only
    template variables.

    Args:
        context: Mapping that may hold "market", "portfolio" and "events"
            entries; missing keys fall back to neutral defaults, and ``None``
            is treated as an empty mapping.
        risk_level: Risk setting on a 0-10 scale. Values are clamped into that
            range and mapped to the band whose lower bound is the largest one
            not exceeding them (e.g. 9 -> "aggressive"); NaN maps to "moderate".
        confidence_boost: > 20 -> "highly confident"; (0, 20] -> "confident";
            < -20 -> "uncertain"; any other value adds no confidence sentence.

    Returns:
        A ``PromptTemplate`` with input variables ``context`` and ``question``.
    """
    import math

    context = context or {}

    # Lower bound of each risk band. A banded lookup gives every value in 0-10
    # a descriptor; an exact int-key lookup reported 1, 3, 7 and 9 as "moderate".
    risk_bands = (
        (0, "extremely conservative"),
        (2, "very conservative"),
        (4, "conservative"),
        (5, "moderate"),
        (6, "moderately aggressive"),
        (8, "aggressive"),
        (10, "very aggressive"),
    )
    if math.isnan(risk_level):
        risk_text = "moderate"
    else:
        level = min(max(risk_level, 0.0), 10.0)
        risk_text = next(label for floor, label in reversed(risk_bands) if level >= floor)

    confidence_text = ""
    if confidence_boost > 20:
        confidence_text = "You are highly confident in your recommendations."
    elif confidence_boost > 0:
        confidence_text = "You are confident in your recommendations."
    elif confidence_boost < -20:
        confidence_text = "You are uncertain and should express caution in your recommendations."

    def _literal(value):
        # Double any braces so caller-supplied text (e.g. a str()'d dict) stays
        # literal when PromptTemplate parses the result as an f-string template.
        return str(value).replace("{", "{{").replace("}", "}}")

    # {{context}} / {{question}} survive str.format() as {context} / {question},
    # the template variables PromptTemplate fills later.
    base_template = """
You are an AI trading advisor for the Quantum Financial Network. Your risk profile is {risk_level}.
{confidence_text}

You provide trading advice based on market data. Consider:
- Current market conditions: {market_context}
- Player portfolio status: {portfolio_context}
- Recent events: {events_context}

Answer the question with appropriate caution/certainty based on your risk profile.

Context:
{{context}}

Question: {{question}}

Answer:
"""

    return PromptTemplate(
        input_variables=["context", "question"],
        template=base_template.format(
            risk_level=risk_text,
            confidence_text=confidence_text,
            market_context=_literal(context.get("market", "normal conditions")),
            portfolio_context=_literal(context.get("portfolio", "standard portfolio")),
            events_context=_literal(context.get("events", "no major events")),
        ),
    )
| |
|
@app.post("/api/ai/chat")
async def chat_with_ai(request: AIMessageRequest):
    """Answer a player's question via the shared AI chain and log the exchange.

    The question is prefixed with a "Risk Profile: <descriptor> (<level>/10)"
    line unless it already mentions "risk tolerance" (case-insensitive) or
    contains "Current Market Scenario". Each interaction, including the
    requested tuning parameters, is appended to the session's
    ``params_history``. The session record is created on first use.

    NOTE(review): ``temperature`` and ``confidence_boost`` are only logged; the
    answer always comes from the module-level ``chain``, so they do not affect
    the model. ``get_ai_with_params`` / ``get_contextual_prompt`` look intended
    for that; confirm before relying on these parameters in an experiment.

    Returns:
        dict with "answer", "sources" (first 100 characters of each source
        document the chain returns, if any) and "interaction_id" (index of the
        new entry in ``params_history``).

    Raises:
        HTTPException: 500 if any step fails; the full traceback is logged
            server-side instead of being sent to the client.
    """
    import logging
    import math

    from fastapi.concurrency import run_in_threadpool

    try:
        session = game_sessions.setdefault(
            request.session_id,
            {
                "chat_history": [],
                "decisions": [],
                "trust_scores": [],
                "params_history": [],
            },
        )

        # Keep only well-formed turns and normalise both halves to strings.
        chat_history_tuples = [
            (str(item[0]), str(item[1]))
            for item in request.chat_history or []
            if isinstance(item, (list, tuple)) and len(item) >= 2
        ]

        enhanced_question = request.question
        if (
            "risk tolerance" not in enhanced_question.lower()
            and "Current Market Scenario" not in enhanced_question
        ):
            # Lower bound of each risk band, so every value in 0-10 gets a
            # descriptor (an exact int-key lookup reported 2, 4, 6 and 8 as
            # "moderate").
            risk_bands = (
                (0, "extremely conservative"),
                (1, "very conservative"),
                (3, "conservative"),
                (5, "moderate"),
                (7, "moderately aggressive"),
                (9, "aggressive"),
                (10, "very aggressive"),
            )
            if math.isnan(request.risk_level):
                risk_text = "moderate"
            else:
                level = min(max(request.risk_level, 0.0), 10.0)
                risk_text = next(
                    label for floor, label in reversed(risk_bands) if level >= floor
                )
            enhanced_question = (
                f"Risk Profile: {risk_text} ({request.risk_level}/10)\n\n{enhanced_question}"
            )

        # The chain call is synchronous; run it in the thread pool so a slow
        # LLM request does not block the event loop for every other client.
        result = await run_in_threadpool(
            chain,
            {"question": enhanced_question, "chat_history": chat_history_tuples},
        )
        answer = result["answer"]

        session["params_history"].append({
            "timestamp": datetime.now().isoformat(),
            "question": request.question,
            "response": answer,
            "risk_level": request.risk_level,
            "temperature": request.temperature,
            "confidence_boost": request.confidence_boost,
        })

        return {
            "answer": answer,
            # `or []` also covers a chain that returns source_documents=None.
            "sources": [
                doc.page_content[:100] for doc in result.get("source_documents") or []
            ],
            "interaction_id": len(session["params_history"]) - 1,
        }

    except Exception as e:
        # Keep the traceback in the server log; don't expose it to clients.
        logging.getLogger(__name__).exception("AI chat request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
| |
|
@app.post("/api/experiment/decision")
async def log_decision(decision: TradingDecision):
    """Record one trading decision in its session's "decisions" log.

    The session record is created on first use. The response carries a
    status marker and the zero-based index of the new entry in that log.
    """
    session_record = game_sessions.setdefault(
        decision.session_id,
        {
            "chat_history": [],
            "decisions": [],
            "trust_scores": [],
            "params_history": [],
        },
    )
    decisions_log = session_record["decisions"]

    entry = {
        "timestamp": datetime.now().isoformat(),
        "symbol": decision.symbol,
        "action": decision.action,
        "quantity": decision.quantity,
        "price": decision.price,
        "ai_advice_followed": decision.ai_advice_followed,
        "trust_score": decision.trust_score,
    }
    decisions_log.append(entry)

    return {"status": "logged", "decision_id": len(decisions_log) - 1}
| |
|
@app.post("/api/experiment/scenario")
async def trigger_scenario(scenario: ScenarioTrigger):
    """Return the canned prompt for a scenario type, echoing the request details.

    No AI call is made here. The response tells the client which prompt to
    show and whether AI assistance is expected ("requires_ai").

    Raises:
        HTTPException: 400 if ``scenario_type`` is not a known scenario.
    """
    known_prompts = {
        "volatility": "The market is experiencing high volatility. A stock in your portfolio has moved 10% in the last hour. What should you do?",
        "large_position": "You're about to make a large position trade ($10,000+). This would represent a significant portion of your portfolio. Should you proceed?",
        "loss_recovery": "You're down 5% today. Would you like advice on whether to cut losses or hold your positions?",
        "news_event": "Breaking news just released that affects several stocks in your watchlist. How should this impact your trading decisions?",
    }

    chosen_prompt = known_prompts.get(scenario.scenario_type)
    if chosen_prompt is None:
        raise HTTPException(status_code=400, detail="Unknown scenario type")

    return dict(
        scenario_type=scenario.scenario_type,
        prompt=chosen_prompt,
        context=scenario.context,
        requires_ai=scenario.ai_required,
    )
| |
|
@app.get("/api/experiment/session/{session_id}")
async def get_session_data(session_id: str):
    """Return the complete in-memory record for one session.

    Raises:
        HTTPException: 404 if the session has never been created.
    """
    try:
        return game_sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Session not found") from None
| |
|
@app.get("/api/experiment/export/{session_id}")
async def export_experiment_data(session_id: str):
    """Return one session's record for offline analysis (serialized as JSON by FastAPI).

    Raises:
        HTTPException: 404 if the session has never been created.
    """
    if session_id in game_sessions:
        return game_sessions[session_id]
    raise HTTPException(status_code=404, detail="Session not found")
| |
|
@app.get("/")
async def root():
    """Send visitors at the site root to the trading game page."""
    from fastapi.responses import RedirectResponse

    destination = "/game/trade.html"
    return RedirectResponse(destination)
| |
|
@app.get("/game")
async def game_redirect():
    """Send requests for the bare /game path to the trading game page."""
    from fastapi.responses import RedirectResponse as Redirect

    return Redirect(url="/game/trade.html")
| |
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
| |
|