Update app.py
Browse files
app.py
CHANGED
|
@@ -1,31 +1,108 @@
|
|
|
|
|
| 1 |
import asyncio
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from fastapi import FastAPI
|
| 3 |
-
from fastapi.responses import StreamingResponse
|
| 4 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# --- FastAPI App with CORS ---
|
| 8 |
-
app = FastAPI(title="
|
| 9 |
app.add_middleware(
|
| 10 |
CORSMiddleware,
|
| 11 |
allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
|
| 12 |
)
|
| 13 |
|
| 14 |
-
# ---
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# --- Server Startup ---
|
| 31 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
import os
|
| 2 |
import asyncio
|
| 3 |
+
import re
|
| 4 |
+
import json
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from datetime import datetime, timezone
|
| 7 |
from fastapi import FastAPI
|
|
|
|
|
|
|
| 8 |
from pydantic import BaseModel
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
import httpx
|
| 11 |
+
import trafilatura
|
| 12 |
+
import google.generativeai as genai
|
| 13 |
|
| 14 |
# --- FastAPI App with CORS ---
app = FastAPI(title="AI Research Agent API")

# Wide-open CORS so browser front-ends hosted anywhere can call this API.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
| 20 |
|
| 21 |
+
# --- Prompts ---
# Every template is filled via str.format(); any literal brace in the template
# must be doubled ({{ }}) or format() treats it as a replacement field and raises.
PROMPT_NORMAL = """Concisely summarize the key points from the following text based on the user's query: "{query}". Focus on the most critical information. PROVIDED TEXT: --- {context_text} ---"""
PROMPT_DEEP = """As a research analyst, synthesize the information from the provided texts into a detailed report. Current Date: {current_date}. User's Query: "{query}". Instructions: Create a detailed report, combining facts from all sources. Cite source URLs inline, like this: (Source: http://...). At the end, create a "## Sources" section listing all unique URLs. Use clear markdown. Provided Texts: --- {context_text} ---"""
# BUG FIX: the JSON example below previously used single braces, which made
# PROMPT_ULTRADEEP_PLANNER.format(query=...) raise and broke ultradeep mode.
PROMPT_ULTRADEEP_PLANNER = """You are a research planner. Based on the user's query, create a research plan. Your output MUST be a valid JSON object like this: {{"queries": ["query 1", "query 2"]}}. Do not add any other text. USER'S QUERY: "{query}" """
PROMPT_ULTRADEEP_SYNTHESIZER = """You are a master research analyst. Synthesize the collected text into a single, comprehensive, well-structured report based on the user's original query: "{query}". Current Date: {current_date}. Instructions: Synthesize a logical narrative organized by topic. If critical info is missing, you can suggest it, but generate the best possible report with the available info. Cite source URLs inline `(Source: http://...)` and conclude with a "## Sources" list. Collected Raw Text: --- {context_text} ---"""
|
| 26 |
+
|
| 27 |
+
# --- Core Logic with Better Error Handling ---
|
| 28 |
+
async def search_web_logic(query: str, serper_api_key: str) -> str:
    """Search the web via Serper for *query*, scrape the result pages, and
    return their extracted text joined by "---" separators.

    Any failure is reported as a string starting with "Error:" rather than
    an exception, so callers can branch on the prefix.
    """
    if not serper_api_key:
        return "Error: Serper API Key is missing."
    try:
        request_headers = {"X-API-KEY": serper_api_key, "Content-Type": "application/json"}
        async with httpx.AsyncClient(timeout=15) as search_client:
            resp = await search_client.post(
                "https://google.serper.dev/search",
                headers=request_headers,
                json={"q": query, "num": 7},
            )
        if resp.status_code == 401:
            return "Error: Invalid Serper API Key."
        if resp.status_code != 200:
            return f"Error: Serper API returned status {resp.status_code}."
        results = resp.json().get("organic", [])
        if not results:
            return f"Error: No web results found for query '{query}'."
        # Fetch all result pages concurrently; failed fetches come back as exceptions.
        async with httpx.AsyncClient(timeout=20, follow_redirects=True) as page_client:
            responses = await asyncio.gather(
                *(page_client.get(r["link"]) for r in results),
                return_exceptions=True,
            )
        texts = []
        for meta, response in zip(results, responses):
            if isinstance(response, Exception):
                continue
            body = trafilatura.extract(response.text)
            if body:
                texts.append(f"Source URL: {meta['link']}\nContent: {body.strip()}\n")
        if not texts:
            return "Error: Found web results, but could not extract content."
        return "\n---\n".join(texts)
    except Exception as e:
        return f"Error during web search: {str(e)}"
|
| 45 |
+
|
| 46 |
+
async def call_gemini(prompt: str, gemini_key: str, model_name: str, json_mode: bool = False) -> str:
    """Send *prompt* to the named Gemini model and return the text reply.

    Failures (missing key or API error) are returned as a JSON string with
    an "error" key instead of raising, matching the callers' conventions.
    """
    if not gemini_key:
        return json.dumps({"error": "Gemini API Key is missing."})
    try:
        genai.configure(api_key=gemini_key)
        model = genai.GenerativeModel(model_name)
        # Request raw JSON output when the caller will machine-parse the reply.
        config = {"response_mime_type": "application/json"} if json_mode else None
        reply = await model.generate_content_async(prompt, generation_config=config)
        return reply.text
    except Exception as e:
        return json.dumps({"error": f"Error calling Gemini: {str(e)}"})
|
| 55 |
+
|
| 56 |
+
# --- AI Agent Logic (Non-Streaming) ---
|
| 57 |
+
async def ultradeep_research_agent(query: str, serper_api_key: str, gemini_key: str, model_name: str) -> str:
    """Three-step plan/search/synthesize pipeline for 'ultradeep' mode.

    Returns the final markdown report, or a string starting with "Error:"
    describing which phase failed.
    """
    # Step 1: Plan — ask Gemini for a JSON object of sub-queries.
    plan_str = await call_gemini(
        PROMPT_ULTRADEEP_PLANNER.format(query=query), gemini_key, model_name, json_mode=True
    )
    try:
        json_blob = re.search(r'\{.*\}', plan_str, re.DOTALL)
        if not json_blob:
            raise ValueError("No JSON object found in Gemini's planner response.")
        plan = json.loads(json_blob.group(0))
        if "error" in plan:
            return f"Error during planning phase: {plan['error']}"
        search_queries = plan["queries"]
    except Exception as e:
        return f"Error: Could not create a valid research plan. Details: {str(e)}\nRaw Response: {plan_str}"

    # Step 2: Execute — run every planned search concurrently.
    search_results = await asyncio.gather(
        *(search_web_logic(sub_query, serper_api_key) for sub_query in search_queries)
    )
    # Drop failed searches; keep whatever content did come back.
    all_scraped_text = "\n".join(res for res in search_results if not res.startswith("Error:"))
    if not all_scraped_text:
        return "Error: Could not retrieve any web content for the planned queries. Check Serper key."

    # Step 3: Synthesize — merge all scraped text into one report.
    current_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    synthesizer_prompt = PROMPT_ULTRADEEP_SYNTHESIZER.format(
        query=query, current_date=current_date, context_text=all_scraped_text
    )
    return await call_gemini(synthesizer_prompt, gemini_key, model_name)
|
| 80 |
+
|
| 81 |
+
# --- The Single, Unified FastAPI Endpoint ---
|
| 82 |
+
class ResearchRequest(BaseModel):
    """Request body for POST /api/research."""
    query: str  # the user's research question
    serper_api_key: str  # caller-supplied Serper (web search) API key
    gemini_api_key: str  # caller-supplied Google Gemini API key
    research_mode: str  # 'ultradeep', 'deep', or anything else for normal mode (see api_research)
    gemini_model: str = "gemini-1.5-flash-latest"  # Gemini model name; defaults to the fast tier
|
| 88 |
+
|
| 89 |
+
@app.post("/api/research")
async def api_research(request: ResearchRequest):
    """Unified research endpoint; always responds with {"result": <markdown or error string>}."""
    # 'ultradeep' delegates to the multi-step plan/search/synthesize agent.
    if request.research_mode == 'ultradeep':
        report = await ultradeep_research_agent(
            request.query, request.serper_api_key, request.gemini_api_key, request.gemini_model
        )
        return {"result": report}

    # Normal and Deep modes: a single search pass followed by one Gemini call.
    scraped_text = await search_web_logic(request.query, request.serper_api_key)
    if scraped_text.startswith("Error:"):
        return {"result": scraped_text}

    today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    template = PROMPT_DEEP if request.research_mode == "deep" else PROMPT_NORMAL
    summary = await call_gemini(
        template.format(query=request.query, context_text=scraped_text, current_date=today),
        request.gemini_api_key,
        request.gemini_model,
    )
    return {"result": summary}
|
| 106 |
|
| 107 |
# --- Server Startup ---
|
| 108 |
if __name__ == "__main__":
|