Spaces:
Running
Running
| """ | |
| FastAPI Backend for Trading Game + AI Chatbot Integration | |
| Supports tunable AI parameters for psychological experiments | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Dict | |
| import json | |
| import os | |
| from datetime import datetime | |
| # Import your existing chain from app.py | |
| from app import chain, llm, ChatOpenAI, PromptTemplate | |
| from langchain_core.prompts import ChatPromptTemplate | |
app = FastAPI(title="Trading Game AI Experiment API")

# CORS. Origins default to "*" for local development; in production set
# CORS_ALLOW_ORIGINS to a comma-separated list such as
# "https://example.com,https://game.example.com".
# Cookies/credentials are enabled only for an explicit origin list: FastAPI's
# CORS documentation says wildcard origins must not be combined with
# allow_credentials=True, and recent Starlette versions respond to that
# combination by echoing the caller's Origin, which would let any site make
# credentialed requests.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
_cors_allow_any_origin = "*" in _cors_origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=not _cors_allow_any_origin,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the static game frontend (e.g. /game/trade.html) from the ./game directory.
app.mount("/game", StaticFiles(directory="game", html=True), name="game")

# Experiment state, held in process memory: it is lost on restart and not shared
# between worker processes (in production, use a database).
experiment_state = {}  # NOTE(review): not referenced by the code visible in this file.
# session_id -> {"chat_history", "decisions", "trust_scores", "params_history"}
game_sessions = {}
class AIMessageRequest(BaseModel):
    """Request body for one AI chat turn, including the experiment's slider settings.

    The numeric ranges noted below are nominal only; this model declares no
    validators or constraints, so out-of-range values are accepted as-is.
    """

    # The player's question text.
    question: str
    # Prior turns. Consumers keep only list/tuple items with at least two
    # elements. NOTE(review): presumably (user, assistant) pairs; confirm with the frontend.
    chat_history: List[tuple] = []
    risk_level: float = 5.0  # 0-10, affects advice certainty
    temperature: float = 0.7  # 0.0-2.0, affects response randomness
    confidence_boost: float = 0.0  # -100 to +100, manipulates trustworthiness
    # Key into the in-memory session store; all requests without one share "default".
    session_id: str = "default"
class TradingDecision(BaseModel):
    """A single player trade, logged to measure reliance on and trust in the AI advisor."""

    # Session the decision belongs to.
    session_id: str
    # Traded instrument's ticker symbol.
    symbol: str
    action: str  # "buy" or "sell" (convention only; not validated here)
    quantity: int
    price: float
    # Whether the player says they followed the AI's advice for this trade.
    ai_advice_followed: bool
    trust_score: Optional[float] = None  # 1-10, explicit trust rating; None if not given
class ScenarioTrigger(BaseModel):
    """Request to start a scripted in-game scenario that may call for AI advice."""

    # Session the scenario is triggered for.
    session_id: str
    # One of "volatility", "large_position", "loss_recovery", "news_event";
    # trigger_scenario rejects any other value with HTTP 400.
    scenario_type: str
    # Free-form scenario details; trigger_scenario echoes it back unchanged.
    context: Dict
    # Returned to the caller as "requires_ai".
    ai_required: bool = True
def get_ai_with_params(risk_level: float, temperature: float, confidence_boost: float):
    """Create a chat LLM client, routed through Hugging Face, tuned for one experiment condition.

    Args:
        risk_level: Advisor risk setting, nominally 0-10. Each point adds 0.03
            to the sampling temperature (+0.3 at 10), so riskier advisors
            answer more variably.
        temperature: Base sampling temperature, nominally 0.0-2.0.
        confidence_boost: Currently unused by this function; confidence is
            expressed through the prompt text in ``get_contextual_prompt``.

    Returns:
        A new ``ChatOpenAI`` instance pointed at the Hugging Face router, with
        the effective temperature clamped to [0.0, 2.0] and ``max_tokens=512``.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is not set.
    """
    # Higher risk -> more variability.
    adjusted_temp = temperature + (risk_level / 10.0) * 0.3
    # Clamp on both sides. OpenAI-compatible APIs take 0-2, and a negative
    # input (e.g. temperature or risk_level below 0) would otherwise be sent
    # through unchanged.
    adjusted_temp = min(max(adjusted_temp, 0.0), 2.0)

    # Model and credentials come from the environment.
    model_name = os.getenv("HF_MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct:novita")
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN environment variable is not set.")

    return ChatOpenAI(
        model=model_name,
        base_url="https://router.huggingface.co/v1",
        api_key=api_key,
        temperature=adjusted_temp,
        max_tokens=512,
    )
def get_contextual_prompt(context: Dict, risk_level: float, confidence_boost: float):
    """Build a PromptTemplate whose instructions reflect risk level and confidence.

    Args:
        context: Scenario details. The optional keys ``"market"``,
            ``"portfolio"`` and ``"events"`` are interpolated into the prompt;
            missing keys fall back to neutral defaults, and ``None`` is treated
            as an empty dict.
        risk_level: Advisor risk setting, nominally 0-10. Mapped to the
            descriptor of the highest threshold not above it; values below 0
            use the most conservative descriptor.
        confidence_boost: > 20 -> "highly confident"; (0, 20] -> "confident";
            < -20 -> "uncertain"; anything else adds no confidence sentence.

    Returns:
        A PromptTemplate with input variables ``context`` and ``question``.
    """
    # (threshold, descriptor) pairs in ascending order. Matching by threshold,
    # not exact key, gives intermediate values such as 3 or 9 a fitting
    # descriptor instead of silently falling back to "moderate".
    risk_thresholds = (
        (0, "extremely conservative"),
        (2, "very conservative"),
        (4, "conservative"),
        (5, "moderate"),
        (6, "moderately aggressive"),
        (8, "aggressive"),
        (10, "very aggressive"),
    )
    risk_text = risk_thresholds[0][1]
    for threshold, descriptor in risk_thresholds:
        if risk_level >= threshold:
            risk_text = descriptor

    if confidence_boost > 20:
        confidence_text = "You are highly confident in your recommendations."
    elif confidence_boost > 0:
        confidence_text = "You are confident in your recommendations."
    elif confidence_boost < -20:
        confidence_text = "You are uncertain and should express caution in your recommendations."
    else:
        confidence_text = ""

    context = context or {}

    def _escape_braces(value) -> str:
        # The formatted text becomes a PromptTemplate (f-string syntax). Literal
        # braces in context values, e.g. a stringified dict, must be doubled so
        # they are not parsed as template variables.
        return str(value).replace("{", "{{").replace("}", "}}")

    # First .format() fills the static parts; the doubled {{context}} and
    # {{question}} survive as PromptTemplate variables.
    base_template = """
You are an AI trading advisor for the Quantum Financial Network. Your risk profile is {risk_level}.
{confidence_text}
You provide trading advice based on market data. Consider:
- Current market conditions: {market_context}
- Player portfolio status: {portfolio_context}
- Recent events: {events_context}
Answer the question with appropriate caution/certainty based on your risk profile.
Context:
{{context}}
Question: {{question}}
Answer:
"""
    return PromptTemplate(
        input_variables=["context", "question"],
        template=base_template.format(
            risk_level=risk_text,
            confidence_text=confidence_text,
            market_context=_escape_braces(context.get("market", "normal conditions")),
            portfolio_context=_escape_braces(context.get("portfolio", "standard portfolio")),
            events_context=_escape_braces(context.get("events", "no major events")),
        ),
    )
async def chat_with_ai(request: AIMessageRequest):
    """Answer a player's question with the shared chain from app.py and log the exchange.

    Args:
        request: The question, prior ``chat_history``, slider parameters and
            ``session_id``. History items that are not a list/tuple of at least
            two elements are dropped; the rest are reduced to ``(str, str)``
            pairs of their first two elements.

    Returns:
        ``{"answer": ..., "sources": [...], "interaction_id": int}``.
        ``sources`` holds the first 100 characters of each source document the
        chain returned (empty if none), and ``interaction_id`` is the index of
        this exchange in the session's ``params_history``.

    Raises:
        HTTPException: 500 if the chain call or response handling fails. The
            traceback is logged server-side and not sent to the client.

    Note:
        Only ``risk_level`` influences the answer; it is prepended to the
        question. ``temperature`` and ``confidence_boost`` are recorded in the
        experiment log but not applied, because the pre-built ``chain`` is
        used as-is.
    """
    import logging

    from fastapi.concurrency import run_in_threadpool

    # (threshold, label) pairs in ascending order. Matching by threshold, not
    # exact key, gives values such as 2, 6 or 8 a label instead of silently
    # falling back to "moderate".
    risk_labels = (
        (0, "extremely conservative"),
        (1, "very conservative"),
        (3, "conservative"),
        (5, "moderate"),
        (7, "moderately aggressive"),
        (9, "aggressive"),
        (10, "very aggressive"),
    )

    try:
        session = game_sessions.setdefault(request.session_id, {
            "chat_history": [],
            "decisions": [],
            "trust_scores": [],
            "params_history": [],
        })

        chat_history_tuples = [
            (str(item[0]), str(item[1]))
            for item in request.chat_history or []
            if isinstance(item, (list, tuple)) and len(item) >= 2
        ]

        # Prepend the risk profile unless the question already mentions risk
        # tolerance or already contains a "Current Market Scenario" block.
        enhanced_question = request.question
        if (
            "risk tolerance" not in enhanced_question.lower()
            and request.risk_level is not None
            and "Current Market Scenario" not in enhanced_question
        ):
            risk_text = risk_labels[0][1]
            for threshold, label in risk_labels:
                if request.risk_level >= threshold:
                    risk_text = label
            enhanced_question = f"Risk Profile: {risk_text} ({request.risk_level}/10)\n\n{enhanced_question}"

        # chain(...) is a synchronous call (presumably a network-bound LLM
        # request); running it in the threadpool keeps this async endpoint from
        # blocking the event loop for every other request.
        # NOTE(review): assumes `chain` is safe to call concurrently (no shared
        # memory object); confirm in app.py.
        result = await run_in_threadpool(
            chain, {"question": enhanced_question, "chat_history": chat_history_tuples}
        )

        answer = result["answer"]
        session["params_history"].append({
            "timestamp": datetime.now().isoformat(),
            "question": request.question,
            "response": answer,
            "risk_level": request.risk_level,
            "temperature": request.temperature,
            "confidence_boost": request.confidence_boost,
        })

        # Tolerate both a missing key and an explicit None.
        source_docs = result.get("source_documents") or []
        return {
            "answer": answer,
            "sources": [doc.page_content[:100] for doc in source_docs],
            "interaction_id": len(session["params_history"]) - 1,
        }
    except Exception as e:
        # Keep the full traceback in the server log. Sending it to clients
        # would expose file paths and internals.
        logging.getLogger(__name__).exception("AI chat request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def log_decision(decision: TradingDecision):
    """Record one player trading decision in its session's experiment log.

    The session entry is created on first use. The response carries the
    position of the new record within that session's ``decisions`` list.
    """
    session = game_sessions.setdefault(
        decision.session_id,
        {"chat_history": [], "decisions": [], "trust_scores": [], "params_history": []},
    )
    decisions = session["decisions"]

    record = dict(
        timestamp=datetime.now().isoformat(),
        symbol=decision.symbol,
        action=decision.action,
        quantity=decision.quantity,
        price=decision.price,
        ai_advice_followed=decision.ai_advice_followed,
        trust_score=decision.trust_score,
    )
    decisions.append(record)

    return {"status": "logged", "decision_id": len(decisions) - 1}
# Canned prompts for each supported scenario type.
_SCENARIO_PROMPTS = {
    "volatility": "The market is experiencing high volatility. A stock in your portfolio has moved 10% in the last hour. What should you do?",
    "large_position": "You're about to make a large position trade ($10,000+). This would represent a significant portion of your portfolio. Should you proceed?",
    "loss_recovery": "You're down 5% today. Would you like advice on whether to cut losses or hold your positions?",
    "news_event": "Breaking news just released that affects several stocks in your watchlist. How should this impact your trading decisions?",
}


async def trigger_scenario(scenario: ScenarioTrigger):
    """Resolve a scenario type to its canned prompt for the frontend to use in an AI chat.

    Raises:
        HTTPException: 400 if ``scenario.scenario_type`` is not a known type.
    """
    prompt = _SCENARIO_PROMPTS.get(scenario.scenario_type)
    if prompt is None:
        raise HTTPException(status_code=400, detail="Unknown scenario type")

    # Nothing is stored server-side; the caller decides whether to start the chat.
    return dict(
        scenario_type=scenario.scenario_type,
        prompt=prompt,
        context=scenario.context,
        requires_ai=scenario.ai_required,
    )
async def get_session_data(session_id: str):
    """Return the live in-memory experiment record for ``session_id``.

    Raises:
        HTTPException: 404 if no session with that id has been created.
    """
    session = game_sessions.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return session
async def export_experiment_data(session_id: str):
    """Return a session's experiment record for offline analysis.

    The dict is returned as-is; JSON encoding is left to the framework's
    response handling.

    Raises:
        HTTPException: 404 if no session with that id has been created.
    """
    if session_id in game_sessions:
        return game_sessions[session_id]
    raise HTTPException(status_code=404, detail="Session not found")
async def root():
    """Send visitors of the site root to the trading game page."""
    from fastapi.responses import RedirectResponse

    game_page = "/game/trade.html"
    return RedirectResponse(url=game_page)
async def game_redirect():
    """Send requests for the bare game entry point to the trading game page."""
    from fastapi.responses import RedirectResponse

    game_page = "/game/trade.html"
    return RedirectResponse(url=game_page)
if __name__ == "__main__":
    # Direct-run entry point. Binds all interfaces; the port defaults to 8000
    # and can be overridden with the PORT environment variable.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))