Spaces:
Paused
Paused
| import os | |
| import uuid | |
| import json | |
| import sqlite3 | |
| import httpx | |
| import requests | |
| from fastapi import FastAPI, Request, HTTPException | |
| from fastapi.responses import HTMLResponse, PlainTextResponse, Response, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from gradio import Server | |
| # Import static strings from bag.py | |
| from bag import ( | |
| BASE_URL, | |
| LLMS_TXT, | |
| SITEMAP_XML, | |
| ROBOTS_TXT, | |
| OVERSEER_JSON, | |
| VIDEO_PAGE_HTML | |
| ) | |
# ASGI application object; the static frontend is mounted on it at the bottom of this file.
app = FastAPI()
# --- Database helpers ---
# Every session gets its own SQLite file inside this directory (see get_db_path).
DATA_DIR = "data"
# Created at import time so the per-session database files have a directory to live in.
os.makedirs(DATA_DIR, exist_ok=True)
def get_db_path(session_id: str) -> str:
    """Return the path of the per-session SQLite file, ``DATA_DIR/session_<id>.db``.

    ``session_id`` arrives straight from client request bodies, so it is
    restricted to ASCII letters, digits, ``-`` and ``_`` (enough for uuid4
    strings). This keeps the resulting path a plain file name inside DATA_DIR.
    It also stops a missing id from silently falling through to a shared
    ``session_None.db``.

    Args:
        session_id: Client-supplied session identifier. Non-string values are
            converted with ``str()``, as the original f-string did.

    Returns:
        The database file path for the session.

    Raises:
        HTTPException(400): if session_id is missing or empty, is longer than
            128 characters, or contains any character outside the allowlist.
    """
    sid = "" if session_id is None else str(session_id)
    is_safe = all(ch.isascii() and (ch.isalnum() or ch in "-_") for ch in sid)
    if not sid or len(sid) > 128 or not is_safe:
        raise HTTPException(status_code=400, detail="Invalid session_id")
    return os.path.join(DATA_DIR, f"session_{sid}.db")
def init_session_db(session_id: str):
    """Create the session's tables if needed and seed a placeholder root node.

    Safe to call repeatedly: both tables use CREATE TABLE IF NOT EXISTS, and
    the root is only inserted while the ``roots`` table is empty. The seeded
    root uses the same uuid4 in ``roots`` and ``nodes``, with node_type 'root'
    and a NULL parent_id.
    """
    db_path = get_db_path(session_id)
    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''CREATE TABLE IF NOT EXISTS nodes (
            id TEXT PRIMARY KEY,
            parent_id TEXT,
            node_type TEXT NOT NULL,
            label TEXT NOT NULL,
            description TEXT DEFAULT '',
            emoji TEXT DEFAULT '',
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )''')
        conn.execute('''CREATE TABLE IF NOT EXISTS roots (
            id TEXT PRIMARY KEY,
            decision TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )''')
        # Ensure root node exists
        root = conn.execute("SELECT id FROM roots LIMIT 1").fetchone()
        if not root:
            root_id = str(uuid.uuid4())
            conn.execute("INSERT INTO roots (id, decision) VALUES (?, 'New Decision')", (root_id,))
            conn.execute("INSERT INTO nodes (id, parent_id, node_type, label, description) VALUES (?, NULL, 'root', 'What decision do you want to explore?', 'Enter a decision at the top of the page to begin.')", (root_id,))
        conn.commit()
    finally:
        # sqlite3 connections are not closed automatically; close even if a statement failed.
        conn.close()
def get_tree_nested(session_id: str) -> dict:
    """Load every node of a session and assemble them into a nested tree.

    Each node is a dict with keys ``id``, ``parent_id``, ``type``, ``label``,
    ``description``, ``emoji`` and ``children`` (a list of nested node dicts).
    Siblings appear in insertion order. A node whose ``parent_id`` matches no
    stored node is left out of the tree.

    Returns:
        The node whose parent_id is NULL (the last one if several exist). If
        the session has no such node, returns the placeholder
        ``{'id': 'error', 'label': 'No root found', 'children': []}``.
    """
    db_path = get_db_path(session_id)
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    try:
        # created_at has one-second resolution, so rowid breaks ties and keeps
        # siblings created in the same second in their insertion order.
        rows = conn.execute("SELECT * FROM nodes ORDER BY created_at, rowid").fetchall()
    finally:
        conn.close()

    node_map = {
        row['id']: {
            'id': row['id'],
            'parent_id': row['parent_id'],
            'type': row['node_type'],
            'label': row['label'],
            'description': row['description'],
            'emoji': row['emoji'],
            'children': [],
        }
        for row in rows
    }

    root = None
    for node in node_map.values():
        parent_id = node['parent_id']
        if parent_id is None:
            root = node
        elif parent_id in node_map:
            node_map[parent_id]['children'].append(node)
    return root or {'id': 'error', 'label': 'No root found', 'children': []}
def build_path_string(session_id: str, node_id: str) -> str:
    """Describe the ancestry of ``node_id``, root first.

    Each step is rendered as ``[NODE_TYPE] label`` and the steps are joined
    with `` → ``, e.g. ``[ROOT] Move abroad? → [INPUT] Take the job``.

    Returns:
        The rendered path, or ``node_id`` itself when the node is not found.
    """
    db_path = get_db_path(session_id)
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    path_parts = []
    visited = set()
    try:
        current_id = node_id
        # `visited` stops the walk if a corrupted parent_id chain ever loops back on itself.
        while current_id and current_id not in visited:
            visited.add(current_id)
            row = conn.execute("SELECT id, parent_id, node_type, label FROM nodes WHERE id=?", (current_id,)).fetchone()
            if not row:
                break
            path_parts.append(f"[{row['node_type'].upper()}] {row['label']}")
            current_id = row['parent_id']
    finally:
        conn.close()
    path_parts.reverse()
    return " → ".join(path_parts) if path_parts else node_id
def get_node_db(session_id: str, node_id: str) -> "dict | None":
    """Fetch one node row as a plain dict.

    Returns:
        A dict with the ``nodes`` columns (id, parent_id, node_type, label,
        description, emoji, created_at), or None if ``node_id`` is not stored
        in this session.
    """
    conn = sqlite3.connect(get_db_path(session_id))
    conn.row_factory = sqlite3.Row
    try:
        row = conn.execute("SELECT * FROM nodes WHERE id=?", (node_id,)).fetchone()
    finally:
        conn.close()
    return dict(row) if row else None
def add_node_db(session_id: str, parent_id: str, node_type: str, label: str, description: str = "", emoji: str = "") -> str:
    """Insert a new node under ``parent_id`` and return its generated uuid4 id.

    Args:
        session_id: Session whose database receives the node.
        parent_id: Id of the parent node.
        node_type: Stored verbatim in ``node_type`` (callers pass 'input' or 'outcome').
        label, description, emoji: Node text. None becomes "" and other
            non-string values are converted with ``str()``.

    Returns:
        The new node's id.
    """
    # LLM JSON may carry null or non-string fields, and `label` is NOT NULL,
    # so normalise everything to text before inserting.
    label = "" if label is None else str(label)
    description = "" if description is None else str(description)
    emoji = "" if emoji is None else str(emoji)

    node_id = str(uuid.uuid4())
    conn = sqlite3.connect(get_db_path(session_id))
    try:
        conn.execute(
            "INSERT INTO nodes (id, parent_id, node_type, label, description, emoji) VALUES (?, ?, ?, ?, ?, ?)",
            (node_id, parent_id, node_type, label, description, emoji)
        )
        conn.commit()
    finally:
        conn.close()
    return node_id
# --- AI Generation ---
# OpenRouter model identifier sent with every chat-completion request in call_api.
DEFAULT_MODEL = "nvidia/nemotron-3-nano-30b-a3b"
# Read once at import time; call_api raises an HTTP 500 when this is empty.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
def call_api(prompt: str, max_tokens: int = 1024) -> str:
    """Send ``prompt`` to the OpenRouter chat-completions API and return the reply text.

    Uses DEFAULT_MODEL at temperature 0.8 with a 60 second timeout.

    Args:
        prompt: The single user message to send.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The first choice's message content, or "" when the model returned no
        content (e.g. ``"content": null``).

    Raises:
        HTTPException(500): if the API key is unset, the request fails or times
            out, the API answers with a non-200 status or a non-JSON body, or
            the response has no choices.
    """
    if not OPENROUTER_API_KEY:
        raise HTTPException(status_code=500, detail="OPENROUTER_API_KEY not set")
    try:
        response = requests.post(
            url="https://openrouter.ai/api/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {OPENROUTER_API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": DEFAULT_MODEL,
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
                "temperature": 0.8
            },
            timeout=60
        )
    except requests.RequestException as e:
        # Connection errors and timeouts would otherwise surface as unhandled exceptions.
        raise HTTPException(status_code=500, detail=f"API request failed: {e}") from e
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail=f"API error: {response.status_code} - {response.text}")
    try:
        data = response.json()
    except ValueError as e:
        raise HTTPException(status_code=500, detail="API returned invalid JSON") from e
    choices = data.get("choices") if isinstance(data, dict) else None
    if not choices or not isinstance(choices, list):
        raise HTTPException(status_code=500, detail="No choices in response")
    first = choices[0] if isinstance(choices[0], dict) else {}
    message = first.get("message") or {}
    # Some models return "content": null; normalise to "" so callers always get a str.
    return (message.get("content") if isinstance(message, dict) else None) or ""
# Sentinel returned by _load_json_payload when no JSON could be decoded.
_NO_JSON = object()


def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (``` or ```json) if present."""
    stripped = text.strip()
    if stripped.startswith("```"):
        # Drop the opening fence line, which may carry a language tag.
        newline = stripped.find("\n")
        stripped = stripped[newline + 1:] if newline != -1 else ""
        stripped = stripped.rstrip()
        if stripped.endswith("```"):
            stripped = stripped[:-3]
    return stripped.strip()


def _load_json_payload(text: str):
    """Decode ``text`` as JSON, falling back to its outermost [...] span; return _NO_JSON on failure."""
    cleaned = _strip_code_fence(text)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    # Models often wrap the array in prose ("Here are your options: [...]").
    start, end = cleaned.find("["), cleaned.rfind("]")
    if start != -1 and end > start:
        try:
            return json.loads(cleaned[start:end + 1])
        except json.JSONDecodeError:
            pass
    return _NO_JSON


def _as_text(value) -> str:
    """Convert a parsed JSON value to stripped text, mapping None to ""."""
    return "" if value is None else str(value).strip()


def _coerce_child(item):
    """Normalise one parsed item to a label/description/emoji dict, or None if it has no label."""
    if isinstance(item, str):
        child = {"label": _as_text(item), "description": "", "emoji": ""}
    elif isinstance(item, dict):
        child = {key: _as_text(item.get(key)) for key in ("label", "description", "emoji")}
    else:
        return None
    return child if child["label"] else None


def parse_children(text: str) -> list:
    """Parse AI response into list of dicts with label, description, emoji.

    Accepted shapes, tried in order:
      1. A JSON array, or an object with a "children" array. It may be wrapped
         in a Markdown code fence or surrounded by prose. Items may be objects
         or plain strings. Items without a non-empty label are dropped.
      2. Otherwise, lines starting with "-" or "*" are taken as labels.

    Returns:
        A list of dicts, each with exactly the string keys "label",
        "description" and "emoji". Empty when nothing usable was found,
        including for empty or None input.
    """
    if not text or not isinstance(text, str):
        return []
    payload = _load_json_payload(text)
    if payload is _NO_JSON:
        # Fallback: split by lines
        children = []
        for line in text.strip().split('\n'):
            line = line.strip()
            if line.startswith(('-', '*')):
                label = line[1:].strip()
                if label:
                    children.append({"label": label, "description": "", "emoji": ""})
        return children
    if isinstance(payload, dict):
        payload = payload.get("children")
    if not isinstance(payload, list):
        return []
    return [child for child in map(_coerce_child, payload) if child is not None]
def _format_user_guidance(comment) -> str:
    """Return a prompt line carrying the user's free-text comment, or "" when there is none."""
    text = str(comment).strip() if comment else ""
    return f"\nUser guidance for this step: {text}" if text else ""


def build_options_prompt(path_context: str, parent_label: str, parent_desc: str, count: int, comment: str) -> str:
    """Build the LLM prompt asking for ``count`` OPTIONS (choices) that follow the current node.

    ``comment`` is optional user guidance. When non-empty it is added as a
    "User guidance for this step:" line right after the node description.
    When it is empty, the prompt is unchanged. The model is told to answer
    with a bare JSON array of {label, description, emoji} objects.
    """
    guidance = _format_user_guidance(comment)
    return f"""You are generating OPTIONS (choices/decisions) for a decision tree.
Full path from root to this node:
{path_context}
Current node: {parent_label}
Description: {parent_desc}{guidance}
Generate {count} distinct, creative options that follow from this node. Each option should be a possible action, choice, or path forward that makes sense given the full context above.
CRITICAL: Respond ONLY with a valid JSON array of objects. Each object must have:
- "label": A short, punchy title (2-6 words)
- "description": 1-2 sentence explanation of this option
- "emoji": A single emoji character representing this option
Example:
[
{{"label": "Start freelancing", "description": "Begin working independently as a freelancer", "emoji": "💼"}},
{{"label": "Take a course", "description": "Enroll in a structured learning program", "emoji": "📚"}}
]
IMPORTANT: Your response must be ONLY the JSON array. No markdown, no explanations, no code blocks."""


def build_outcomes_prompt(path_context: str, parent_label: str, parent_desc: str, count: int, comment: str) -> str:
    """Build the LLM prompt asking for ``count`` OUTCOMES (consequences) of the current node.

    ``comment`` is optional user guidance. When non-empty it is added as a
    "User guidance for this step:" line right after the node description.
    When it is empty, the prompt is unchanged. The model is told to answer
    with a bare JSON array of {label, description, emoji} objects.
    """
    guidance = _format_user_guidance(comment)
    return f"""You are generating OUTCOMES (results/consequences) for a decision tree.
Full path from root to this node:
{path_context}
Current node: {parent_label}
Description: {parent_desc}{guidance}
Generate {count} distinct, realistic outcomes that could result from this choice. Each outcome should feel like a natural consequence given the full decision history above.
CRITICAL: Respond ONLY with a valid JSON array of objects. Each object must have:
- "label": A short, punchy title (2-6 words)
- "description": 1-2 sentence explanation of this outcome
- "emoji": A single emoji character representing this outcome
Example:
[
{{"label": "Financial stability improves", "description": "The freelancer enjoys a steady income over time", "emoji": "💰"}},
{{"label": "Loneliness sets in", "description": "Working alone leads to feelings of isolation", "emoji": "😔"}}
]
IMPORTANT: Your response must be ONLY the JSON array. No markdown, no explanations, no code blocks."""
# --- API Endpoints ---
# NOTE(review): no route decorators are visible on these handlers in this view;
# confirm how they are registered on `app`.
async def get_llms_txt():
    """Serve the LLMS_TXT document imported from bag.py as plain text."""
    return PlainTextResponse(content=LLMS_TXT)
async def get_sitemap():
    """Serve SITEMAP_XML from bag.py with an XML media type."""
    return Response(media_type="application/xml", content=SITEMAP_XML)
async def get_robots():
    """Serve ROBOTS_TXT from bag.py as plain text."""
    return PlainTextResponse(content=ROBOTS_TXT)
async def get_overthinker_json():
    """Serve the OVERSEER_JSON string from bag.py with a JSON media type."""
    return Response(media_type="application/json", content=OVERSEER_JSON)
async def get_video():
    """Serve the VIDEO_PAGE_HTML page from bag.py."""
    return HTMLResponse(VIDEO_PAGE_HTML)
async def create_root(request: Request):
    """Create (or reuse) a session database and optionally set the root decision.

    JSON body keys:
        session_id: optional. A new uuid4 is generated when it is missing,
            null or empty.
        decision: optional. When non-empty, it is written to ``roots.decision``
            and to the root node's label.

    Returns:
        {"session_id": ..., "tree": <nested tree>, "path": <root path string>}

    Raises:
        HTTPException(400): if the body is not a JSON object.
    """
    body = await request.json()
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="JSON object body required")
    # `or` (not a .get default) so an explicit null/"" also gets a fresh id.
    session_id = body.get("session_id") or str(uuid.uuid4())
    decision = body.get("decision", "")
    init_session_db(session_id)
    conn = sqlite3.connect(get_db_path(session_id))
    try:
        if decision:
            root_row = conn.execute("SELECT id FROM roots LIMIT 1").fetchone()
            if root_row:
                # The root shares its id between `roots` and `nodes` (see init_session_db).
                conn.execute("UPDATE roots SET decision=? WHERE id=?", (decision, root_row[0]))
                conn.execute("UPDATE nodes SET label=? WHERE id=?", (decision, root_row[0]))
                conn.commit()
    finally:
        conn.close()
    tree = get_tree_nested(session_id)
    path = build_path_string(session_id, tree['id'])
    return {"session_id": session_id, "tree": tree, "path": path}
def _require_session_and_node(body) -> tuple:
    """Return (session_id, node_id) from a request body, rejecting missing values with HTTP 400."""
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="JSON object body required")
    session_id = body.get("session_id")
    node_id = body.get("node_id")
    if not session_id or not node_id:
        raise HTTPException(status_code=400, detail="session_id and node_id required")
    return session_id, node_id


def _fetch_nodes(session_id: str, node_ids: list) -> list:
    """Return the stored rows for ``node_ids`` as dicts, in the given order, skipping missing ids."""
    conn = sqlite3.connect(get_db_path(session_id))
    conn.row_factory = sqlite3.Row
    try:
        rows = [conn.execute("SELECT * FROM nodes WHERE id=?", (nid,)).fetchone() for nid in node_ids]
    finally:
        conn.close()
    return [dict(row) for row in rows if row]


def _expand_node(session_id: str, node_id: str, parent: dict, next_type: str, count, comment) -> dict:
    """Ask the LLM for children of ``parent``, store them as ``next_type`` nodes, and build the response.

    Returns:
        {"children", "parent_label", "parent_desc", "path_context", "next_type"}

    Raises:
        HTTPException(500): if generation fails or yields no usable children.
    """
    parent_label = parent.get('label', '')
    parent_desc = parent.get('description', '')
    path_context = build_path_string(session_id, node_id)
    build_prompt = build_options_prompt if next_type == 'input' else build_outcomes_prompt
    prompt = build_prompt(path_context, parent_label, parent_desc, count, comment)
    try:
        text = call_api(prompt, max_tokens=2048)
        children = parse_children(text)
    except HTTPException:
        # Already carries a meaningful status and detail (e.g. missing API key).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e
    # Only dict items can be stored; anything else would crash on .get below.
    children = [child for child in (children or []) if isinstance(child, dict)]
    if not children:
        raise HTTPException(status_code=500, detail="Generation failed. AI returned empty results.")
    child_ids = [
        add_node_db(session_id, node_id, next_type, child.get('label', ''), child.get('description', ''), child.get('emoji', ''))
        for child in children
    ]
    return {
        "children": _fetch_nodes(session_id, child_ids),
        "parent_label": parent_label,
        "parent_desc": parent_desc,
        "path_context": path_context,
        "next_type": next_type
    }


async def get_children(request: Request):
    """Generate, store and return children for a node, taking the node's type from the client.

    JSON body keys: session_id and node_id (required); count (default 3);
    node_type (default "outcome"); comment (default ""). A node_type of
    "outcome" produces "input" (option) children; anything else produces
    "outcome" children.

    Raises:
        HTTPException(400): missing session_id/node_id or non-object body.
        HTTPException(404): node_id not found in the session.
        HTTPException(500): generation failed.
    """
    from fastapi.concurrency import run_in_threadpool

    body = await request.json()
    session_id, node_id = _require_session_and_node(body)
    node_type = body.get("node_type", "outcome")
    init_session_db(session_id)
    parent = get_node_db(session_id, node_id)
    if not parent:
        raise HTTPException(status_code=404, detail="Node not found")
    next_type = "input" if node_type == "outcome" else "outcome"
    # The LLM request and SQLite writes block; run them off the event loop.
    return await run_in_threadpool(
        _expand_node, session_id, node_id, parent, next_type,
        body.get("count", 3), body.get("comment", ""),
    )


async def add_options(request: Request):
    """Generate, store and return more children for a node, using the node's stored type.

    JSON body keys: session_id and node_id (required); count (default 3);
    comment (default ""). A stored node_type of "outcome" produces "input"
    (option) children; anything else, including 'root', produces "outcome"
    children.

    Raises:
        HTTPException(400): missing session_id/node_id or non-object body.
        HTTPException(404): node_id not found in the session.
        HTTPException(500): generation failed.
    """
    from fastapi.concurrency import run_in_threadpool

    body = await request.json()
    session_id, node_id = _require_session_and_node(body)
    init_session_db(session_id)
    parent = get_node_db(session_id, node_id)
    if not parent:
        raise HTTPException(status_code=404, detail="Node not found")
    next_type = "input" if parent['node_type'] == "outcome" else "outcome"
    # The LLM request and SQLite writes block; run them off the event loop.
    return await run_in_threadpool(
        _expand_node, session_id, node_id, parent, next_type,
        body.get("count", 3), body.get("comment", ""),
    )
async def upload_trace(request: Request):
    """Upload a session's decision tree as JSON to a Hugging Face dataset repo.

    The JSON body must contain "session_id". The repo comes from
    HF_DATASET_REPO (default "build-small-hackathon/Overthinker-trace") and
    the bearer token from HF_TOKEN.

    Returns:
        {"status": "ok", "filename": "trace_<session_id>.json"}

    Raises:
        HTTPException(400): non-object body or missing session_id.
        HTTPException(404): the session has no database or no root node.
        HTTPException(500): HF_TOKEN/HF_DATASET_REPO unset, or the upload failed.
    """
    body = await request.json()
    session_id = body.get("session_id") if isinstance(body, dict) else None
    if not session_id:
        raise HTTPException(status_code=400, detail="session_id required")
    # Without this check, get_tree_nested would create an empty database file
    # for an unknown session and then fail on the missing `nodes` table.
    if not os.path.exists(get_db_path(session_id)):
        raise HTTPException(status_code=404, detail="No tree found")
    tree = get_tree_nested(session_id)
    # get_tree_nested signals "no root" with a placeholder dict, not a falsy value.
    if not tree or tree.get('id') == 'error':
        raise HTTPException(status_code=404, detail="No tree found")
    # Upload to Hugging Face Dataset via REST API
    hf_token = os.environ.get("HF_TOKEN", "")
    dataset_repo = os.environ.get("HF_DATASET_REPO", "build-small-hackathon/Overthinker-trace")
    if not hf_token or not dataset_repo:
        raise HTTPException(status_code=500, detail="HF_TOKEN or HF_DATASET_REPO not set")
    trace_data = json.dumps(tree, indent=2)
    filename = f"trace_{session_id}.json"
    # NOTE(review): confirm this endpoint against the current Hub HTTP API; the
    # huggingface_hub client uploads files through the commit API instead.
    url = f"https://huggingface.co/api/datasets/{dataset_repo}/upload"
    files = {'file': (filename, trace_data, 'application/json')}
    headers = {'Authorization': f'Bearer {hf_token}'}
    try:
        # Without a timeout a stalled Hub connection would hang this handler indefinitely.
        response = requests.post(url, headers=headers, files=files, timeout=60)
    except requests.RequestException as e:
        raise HTTPException(status_code=500, detail=f"Upload failed: {e}") from e
    if response.status_code not in (200, 201):
        raise HTTPException(status_code=500, detail=f"Upload failed: {response.status_code} - {response.text}")
    return {"status": "ok", "filename": filename}
# --- Serve static frontend ---
# Serves files from ./templates; html=True makes Starlette return index.html for
# directory paths such as "/". A mount at "/" matches every path, so API routes
# presumably need to be registered before this line to take precedence.
# NOTE(review): confirm against the route registration, which is not visible here.
app.mount("/", StaticFiles(directory="templates", html=True), name="templates")
if __name__ == "__main__":
    # Local entry point: listens on all interfaces. Port 7860 is presumably chosen
    # to match the Hugging Face Spaces default; confirm.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)