import asyncio
import logging
import os
import time
from typing import Any, Dict, List, Union

import google.generativeai as genai
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, validator
| |
|
# load_dotenv() runs first so the os.getenv() reads below see any values it loads.
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ASGI application object; the /generate_brief route is registered on it.
app = FastAPI(title="Language Agent (Gemini Pro - Generalized)")

# Credentials and model selection come from the environment.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")

# Configure the SDK once at import time. A missing key or a configuration
# failure is only logged here; the endpoint rejects requests when the key
# is absent.
if GOOGLE_API_KEY:
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
    except Exception as e:
        logger.error(f"Failed to configure Google Generative AI: {e}")
    else:
        logger.info(f"Google Generative AI configured for model {GEMINI_MODEL_NAME}.")
else:
    logger.warning("GOOGLE_API_KEY not found.")
| |
|
| |
|
class EarningsSummaryLLM(BaseModel):
    """One company's earnings surprise, as consumed by the prompt builder."""

    # Symbol printed verbatim in the generated earnings sentence.
    ticker: str
    # Surprise in percent: values >= 0 are worded as a beat, negatives as a miss.
    surprise_pct: float
| |
|
| |
|
class AnalysisDataLLM(BaseModel):
    """Portfolio analysis figures that get summarized in the Gemini prompt."""

    # Name of the segment being described; rendered as "For <target_label>, ...".
    target_label: str = "the portfolio"
    # Allocations are fractions of AUM: the prompt builder multiplies by 100.
    current_allocation: float = 0.0
    yesterday_allocation: float = 0.0
    # Day-over-day change, already expressed in percentage points.
    allocation_change_percentage_points: float = 0.0

    # Populated from the "earnings_surprises_for_target" key of the payload.
    earnings_surprises: List[EarningsSummaryLLM] = Field(
        alias="earnings_surprises_for_target",
        default_factory=list,
    )
| |
|
| |
|
class BriefRequest(BaseModel):
    """Request body accepted by the /generate_brief endpoint."""

    # Free-text question from the user; embedded in the prompt as-is.
    user_query: str
    # Structured analysis figures to summarize.
    analysis: AnalysisDataLLM
    # Retrieved document snippets; the endpoint uses at most the first two.
    retrieved_docs: List[str] = Field(default_factory=list)
| |
|
| |
|
def construct_gemini_prompt(
    user_query: str, analysis_data: "AnalysisDataLLM", docs_context: str
) -> str:
    """Build the text prompt sent to Gemini for a spoken-style morning brief.

    Args:
        user_query: The user's question; embedded verbatim in the prompt.
        analysis_data: Allocation figures and earnings surprises. Allocations
            are fractions (multiplied by 100 for display), while the
            allocation change is already in percentage points.
        docs_context: Pre-joined text from retrieved documents; embedded
            verbatim in the prompt.

    Returns:
        The complete prompt string.
    """
    change_pp = analysis_data.allocation_change_percentage_points
    yesterday_pct = f"{analysis_data.yesterday_allocation * 100:.0f}%"

    # Changes within +/-0.01 percentage points are reported as "stable".
    if change_pp > 0.01:
        alloc_change_str = (
            f"up by {change_pp:.1f} percentage points "
            f"from yesterday (approx. {yesterday_pct})."
        )
    elif change_pp < -0.01:
        alloc_change_str = (
            f"down by {abs(change_pp):.1f} percentage points "
            f"from yesterday (approx. {yesterday_pct})."
        )
    else:
        alloc_change_str = f"remaining stable around {yesterday_pct} yesterday."

    analysis_summary_str = (
        f"For {analysis_data.target_label}, the current allocation is "
        f"{analysis_data.current_allocation * 100:.0f}% of AUM, {alloc_change_str}\n"
    )

    # One phrase per surprise; a missing/empty list simply yields no phrases.
    earnings_parts = []
    for item in analysis_data.earnings_surprises or []:
        verb = "beat" if item.surprise_pct >= 0 else "missed"
        earnings_parts.append(
            f"{item.ticker} {verb} estimates by {abs(item.surprise_pct):.1f}%"
        )

    if earnings_parts:
        analysis_summary_str += (
            "Key earnings updates: " + ", ".join(earnings_parts) + "."
        )
    else:
        analysis_summary_str += (
            "No notable earnings surprises reported for this segment."
        )

    # Adjacent literals are concatenated, so every instruction sentence must
    # end with a space to avoid running into the next one.
    prompt = (
        "You are a professional financial assistant. Based on the user's query and the provided data, "
        "deliver a concise, spoken-style morning market brief for a portfolio manager. "
        "The brief should start with 'Good morning.'\n\n"
        f"User Query: {user_query}\n\n"
        f"Key Portfolio and Market Analysis:\n{analysis_summary_str}\n\n"
        f"Relevant Filings Context (if any):\n{docs_context}\n\n"
        "If the user's query mentions a specific region or sector not covered by the 'Key Portfolio and Market Analysis', "
        "you can state that specific data for that exact query aspect was not available in the analysis provided. "
        "Mention any specific company earnings surprises from the analysis clearly (e.g., 'TSMC beat estimates by X%, Samsung missed by Y%'). "
        "If there's information about broad regional sentiment or rising yields in the 'docs_context', incorporate it naturally. Otherwise, focus on the provided analysis."
    )
    return prompt
| |
|
| |
|
# Generation parameters passed to the Gemini model built in the endpoint.
generation_config = genai.types.GenerationConfig(
    temperature=0.6,
    max_output_tokens=1024,
)

# Every listed harm category uses the same blocking threshold.
safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
]
| |
|
| |
|
def _is_rate_limit_error(exc: Exception) -> bool:
    """Return True when the exception text suggests a rate-limit/quota error."""
    text = str(exc)
    lowered = text.lower()
    return (
        "rate limit" in lowered
        or "quota" in lowered
        or "429" in text
        or "resource_exhausted" in lowered
    )


@app.post("/generate_brief")
async def generate_brief(request: BriefRequest):
    """Generate a spoken-style market brief with Gemini.

    Builds a prompt from the user query, the analysis payload and at most the
    first two retrieved documents, then calls the configured Gemini model.
    Empty responses and unexpected errors are retried once; rate-limit style
    errors are retried with an exponentially growing delay.

    Returns:
        A dict of the form ``{"brief": <generated text>}``.

    Raises:
        HTTPException: 500 when the API key is missing, when Gemini returns
            empty content on the final attempt, or when an unexpected error
            persists; 400 when the prompt/response is blocked; 429 when a
            rate-limit or quota error persists after the retries.
    """
    if not GOOGLE_API_KEY:
        raise HTTPException(status_code=500, detail="Google API Key not configured.")
    logger.info(
        f"Generating brief for query: '{request.user_query}' using Gemini model {GEMINI_MODEL_NAME}"
    )

    # Only the first two retrieved documents are placed in the prompt.
    docs_context = (
        "\n".join(request.retrieved_docs[:2])
        if request.retrieved_docs
        else "No relevant context from documents found."
    )

    full_prompt = construct_gemini_prompt(
        user_query=request.user_query,
        analysis_data=request.analysis,
        docs_context=docs_context,
    )
    logger.debug(f"Full prompt for Gemini:\n{full_prompt}")

    try:
        model = genai.GenerativeModel(
            model_name=GEMINI_MODEL_NAME,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )
        max_retries = 1
        retry_delay_seconds = 10
        for attempt in range(max_retries + 1):
            try:
                response = await model.generate_content_async(full_prompt)

                if not response.parts:
                    if (
                        response.prompt_feedback
                        and response.prompt_feedback.block_reason
                    ):
                        # NOTE(review): block_reason_message may not exist on
                        # every SDK version's prompt feedback; read it
                        # defensively so a missing attribute cannot mask the
                        # block as a generic failure.
                        block_reason_message = (
                            getattr(
                                response.prompt_feedback,
                                "block_reason_message",
                                None,
                            )
                            or "Unknown safety block"
                        )
                        logger.error(
                            f"Gemini content generation blocked. Reason: {block_reason_message}"
                        )
                        raise HTTPException(
                            status_code=400,
                            detail=f"Content generation blocked: {block_reason_message}",
                        )

                    logger.error("Gemini response has no parts (empty content).")
                    if attempt == max_retries:
                        raise HTTPException(
                            status_code=500,
                            detail="Gemini returned empty content after retries.",
                        )
                    logger.warning(
                        f"Gemini returned empty content, attempt {attempt+1}/{max_retries+1}. Retrying..."
                    )
                    await asyncio.sleep(retry_delay_seconds)
                    continue

                brief_text = response.text
                logger.info("Gemini content generated successfully.")
                return {"brief": brief_text}

            except HTTPException:
                # HTTP errors raised deliberately above must keep their status
                # code; without this clause the generic handler below would
                # retry them and re-wrap them as a 500.
                raise
            except (
                genai.types.generation_types.BlockedPromptException,
                genai.types.generation_types.StopCandidateException,
            ) as sce_bpe:
                logger.error(
                    f"Gemini generation issue on attempt {attempt+1}: {sce_bpe}"
                )
                raise HTTPException(
                    status_code=400, detail=f"Gemini generation issue: {sce_bpe}"
                ) from sce_bpe
            except Exception as e:
                logger.error(
                    f"Error during Gemini generation on attempt {attempt+1}: {type(e).__name__} - {e}"
                )
                if _is_rate_limit_error(e):
                    if attempt < max_retries:
                        # Exponential backoff: the delay doubles per attempt.
                        wait_time = retry_delay_seconds * (2**attempt)
                        logger.info(f"Rate limit likely. Retrying in {wait_time}s...")
                        await asyncio.sleep(wait_time)
                        continue
                    logger.error("Max retries reached for rate limit.")
                    raise HTTPException(
                        status_code=429,
                        detail=f"Gemini API rate limit/quota exceeded: {e}",
                    ) from e
                if attempt < max_retries:
                    await asyncio.sleep(retry_delay_seconds)
                    continue
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to generate brief with Gemini: {e}",
                ) from e

        # Defensive guard: every loop path above returns, continues or raises.
        raise HTTPException(
            status_code=500, detail="Brief generation failed after all attempts."
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Critical error in /generate_brief: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Critical failure in brief generation: {e}"
        ) from e
| |
|