bk939448 committed on
Commit
d93d1aa
·
verified ·
1 Parent(s): 9dcdc67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -17
app.py CHANGED
@@ -1,31 +1,108 @@
 
1
  import asyncio
 
 
 
 
2
  from fastapi import FastAPI
3
- from fastapi.responses import StreamingResponse
4
- from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
 
 
 
 
6
 
7
  # --- FastAPI App with CORS ---
8
- app = FastAPI(title="Streaming Test App")
9
  app.add_middleware(
10
  CORSMiddleware,
11
  allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
12
  )
13
 
14
- # --- Dummy Streaming Function ---
15
- async def dummy_streamer():
16
- yield "STATUS: Connection established! Starting test...\n"
17
- await asyncio.sleep(2)
18
-
19
- for i in range(1, 6):
20
- yield f"MESSAGE: Ping #{i} from server.\n"
21
- await asyncio.sleep(1)
22
-
23
- yield "FINAL: Test complete. Connection is working!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # --- Test Endpoint ---
26
- @app.post("/api/test-stream")
27
- async def api_test_stream():
28
- return StreamingResponse(dummy_streamer(), media_type="text/event-stream")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # --- Server Startup ---
31
  if __name__ == "__main__":
 
1
+ import os
2
  import asyncio
3
+ import re
4
+ import json
5
+ from typing import Optional
6
+ from datetime import datetime, timezone
7
  from fastapi import FastAPI
 
 
8
  from pydantic import BaseModel
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ import httpx
11
+ import trafilatura
12
+ import google.generativeai as genai
13
 
14
# --- FastAPI App with CORS ---
# Single FastAPI application instance for the research agent API.
# CORS is wide open: any origin, any method, any header, with credentials.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# broader than most deployments want — confirm before production use.
app = FastAPI(title="AI Research Agent API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
)
20
 
21
# --- Prompts ---
# All templates are rendered with str.format(), so any literal brace that must
# survive rendering (e.g. the JSON example in the planner prompt) has to be
# escaped as {{ }}.  The original planner template used bare braces around the
# JSON example, which made .format(query=...) raise KeyError('"queries"') and
# broke ultradeep mode at the planning phase.
PROMPT_NORMAL = """Concisely summarize the key points from the following text based on the user's query: "{query}". Focus on the most critical information. PROVIDED TEXT: --- {context_text} ---"""
PROMPT_DEEP = """As a research analyst, synthesize the information from the provided texts into a detailed report. Current Date: {current_date}. User's Query: "{query}". Instructions: Create a detailed report, combining facts from all sources. Cite source URLs inline, like this: (Source: http://...). At the end, create a "## Sources" section listing all unique URLs. Use clear markdown. Provided Texts: --- {context_text} ---"""
PROMPT_ULTRADEEP_PLANNER = """You are a research planner. Based on the user's query, create a research plan. Your output MUST be a valid JSON object like this: {{"queries": ["query 1", "query 2"]}}. Do not add any other text. USER'S QUERY: "{query}" """
PROMPT_ULTRADEEP_SYNTHESIZER = """You are a master research analyst. Synthesize the collected text into a single, comprehensive, well-structured report based on the user's original query: "{query}". Current Date: {current_date}. Instructions: Synthesize a logical narrative organized by topic. If critical info is missing, you can suggest it, but generate the best possible report with the available info. Cite source URLs inline `(Source: http://...)` and conclude with a "## Sources" list. Collected Raw Text: --- {context_text} ---"""
26
+
27
# --- Core Logic with Better Error Handling ---
async def search_web_logic(query: str, serper_api_key: str) -> str:
    """Search the web via Serper, fetch the hit pages, and return extracted text.

    Returns the extracted article bodies joined by '---' separators, or a
    string starting with "Error:" describing the failure — this function
    never raises, so callers branch on the "Error:" prefix.
    """
    if not serper_api_key:
        return "Error: Serper API Key is missing."
    try:
        # Query the Serper search API for up to 7 organic results.
        request_headers = {"X-API-KEY": serper_api_key, "Content-Type": "application/json"}
        async with httpx.AsyncClient(timeout=15) as search_client:
            search_resp = await search_client.post(
                "https://google.serper.dev/search",
                headers=request_headers,
                json={"q": query, "num": 7},
            )
        if search_resp.status_code == 401:
            return "Error: Invalid Serper API Key."
        if search_resp.status_code != 200:
            return f"Error: Serper API returned status {search_resp.status_code}."
        organic_hits = search_resp.json().get("organic", [])
        if not organic_hits:
            return f"Error: No web results found for query '{query}'."

        # Fetch every result page concurrently; per-URL failures become
        # exception objects (return_exceptions=True) and are skipped below.
        hit_urls = [hit["link"] for hit in organic_hits]
        async with httpx.AsyncClient(timeout=20, follow_redirects=True) as page_client:
            fetches = [page_client.get(u) for u in hit_urls]
            pages = await asyncio.gather(*fetches, return_exceptions=True)

        # Extract the readable article text from each successfully fetched page.
        extracted = []
        for hit, page in zip(organic_hits, pages):
            if isinstance(page, Exception):
                continue
            article = trafilatura.extract(page.text)
            if article:
                extracted.append(f"Source URL: {hit['link']}\nContent: {article.strip()}\n")
        if not extracted:
            return "Error: Found web results, but could not extract content."
        return "\n---\n".join(extracted)
    except Exception as e:
        return f"Error during web search: {str(e)}"
45
+
46
async def call_gemini(prompt: str, gemini_key: str, model_name: str, json_mode: bool = False) -> str:
    """Send *prompt* to the named Gemini model and return the raw text reply.

    When json_mode is True the model is asked for an application/json
    response.  Failures (missing key, API errors) are returned as a JSON
    string of the form '{"error": ...}' instead of raising.
    """
    if not gemini_key:
        return json.dumps({"error": "Gemini API Key is missing."})
    try:
        genai.configure(api_key=gemini_key)
        # Only constrain the response MIME type when the caller wants JSON.
        gen_config = {"response_mime_type": "application/json"} if json_mode else None
        model = genai.GenerativeModel(model_name)
        reply = await model.generate_content_async(prompt, generation_config=gen_config)
        return reply.text
    except Exception as e:
        return json.dumps({"error": f"Error calling Gemini: {str(e)}"})
55
+
56
# --- AI Agent Logic (Non-Streaming) ---
async def ultradeep_research_agent(query: str, serper_api_key: str, gemini_key: str, model_name: str) -> str:
    """Plan -> search -> synthesize pipeline for the 'ultradeep' research mode.

    Returns the final markdown report, or a string starting with "Error:"
    identifying which phase failed.
    """
    # Step 1: Plan — ask Gemini for a JSON object of sub-queries.
    plan_raw = await call_gemini(
        PROMPT_ULTRADEEP_PLANNER.format(query=query), gemini_key, model_name, json_mode=True
    )
    try:
        # Tolerate extra prose around the JSON: grab the outermost {...} span.
        json_match = re.search(r'\{.*\}', plan_raw, re.DOTALL)
        if not json_match:
            raise ValueError("No JSON object found in Gemini's planner response.")
        plan = json.loads(json_match.group(0))
        if "error" in plan:
            return f"Error during planning phase: {plan['error']}"
        sub_queries = plan["queries"]
    except Exception as e:
        return f"Error: Could not create a valid research plan. Details: {str(e)}\nRaw Response: {plan_raw}"

    # Step 2: Execute — run every planned search concurrently, drop failures.
    search_results = await asyncio.gather(
        *(search_web_logic(sub_query, serper_api_key) for sub_query in sub_queries)
    )
    usable_results = [res for res in search_results if not res.startswith("Error:")]
    all_scraped_text = "\n".join(usable_results)
    if not all_scraped_text:
        return "Error: Could not retrieve any web content for the planned queries. Check Serper key."

    # Step 3: Synthesize — merge all scraped material into one report.
    current_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    synthesizer_prompt = PROMPT_ULTRADEEP_SYNTHESIZER.format(
        query=query, current_date=current_date, context_text=all_scraped_text
    )
    return await call_gemini(synthesizer_prompt, gemini_key, model_name)
80
+
81
# --- The Single, Unified FastAPI Endpoint ---
class ResearchRequest(BaseModel):
    """Request body for POST /api/research."""
    query: str  # the user's research question
    serper_api_key: str  # caller-supplied Serper.dev search key
    gemini_api_key: str  # caller-supplied Google Gemini key
    research_mode: str  # 'ultradeep' or 'deep'; anything else falls back to normal mode
    gemini_model: str = "gemini-1.5-flash-latest"  # Gemini model name to use
88
+
89
@app.post("/api/research")
async def api_research(request: ResearchRequest):
    """Single entry point: dispatch the query to the selected research mode."""
    if request.research_mode == 'ultradeep':
        # Multi-step agent: plan sub-queries, search each, synthesize a report.
        report = await ultradeep_research_agent(
            request.query, request.serper_api_key, request.gemini_api_key, request.gemini_model
        )
        return {"result": report}

    # Normal and Deep modes: one search pass followed by one Gemini call.
    scraped_text = await search_web_logic(request.query, request.serper_api_key)
    if scraped_text.startswith("Error:"):
        # Surface the search failure directly instead of summarizing it.
        return {"result": scraped_text}

    current_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    template = PROMPT_DEEP if request.research_mode == "deep" else PROMPT_NORMAL
    summary = await call_gemini(
        template.format(query=request.query, context_text=scraped_text, current_date=current_date),
        request.gemini_api_key,
        request.gemini_model,
    )
    return {"result": summary}
106
 
107
  # --- Server Startup ---
108
  if __name__ == "__main__":