# PetroMind backend — FastAPI application for Oil & Gas document extraction and chat.
| import os | |
| import json | |
| import re | |
| import datetime | |
| import io | |
| from typing import Optional, List, Dict, Any, Tuple | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| import google.generativeai as genai | |
| from google.generativeai import caching | |
| from dotenv import load_dotenv | |
| import base64 | |
| from groq import Groq | |
| from groq import Groq | |
| from mistralai import Mistral | |
| from openai import OpenAI | |
| import fitz # PyMuPDF | |
| from PIL import Image | |
| from rlm import RLMEngine | |
| from rlm_extraction import RLMExtractionEngine | |
| # EasyOCR will be imported lazily when needed to avoid startup issues with torch dependencies | |
# Load environment variables (e.g. provider API keys) from a local .env file.
load_dotenv()

app = FastAPI()

# Set up CORS so the browser frontend can call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For development, allow all. stricter in prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    max_age=3600,  # Cache preflight responses for one hour.
)

# Deployment Helpers - Path to frontend static files
FRONTEND_DIST_PATH = os.path.join(os.path.dirname(__file__), "static")

# NOTE(review): TrustedHostMiddleware is imported here but never registered
# with the app — confirm whether it should be added or removed.
from fastapi.middleware.trustedhost import TrustedHostMiddleware

# Maximum accepted upload size, enforced manually in analyze_document.
MAX_UPLOAD_SIZE = 50 * 1024 * 1024  # 50MB

# RLM Extraction Thresholds
# Documents exceeding these thresholds will use RLM for extraction
RLM_PAGE_THRESHOLD = 300  # Use RLM for documents with more than 300 pages
RLM_TOKEN_THRESHOLD = 200000  # Or more than 200k estimated tokens

# Global storage for the current session. NOTE(review): this is process-wide
# single-document state — concurrent users would overwrite each other.
CURRENT_FILE_CONTENT = None   # Raw bytes of the last uploaded document.
CURRENT_MIME_TYPE = None      # MIME type of the last uploaded document.
CURRENT_CACHE_NAME = None     # Active Gemini context-cache name, if any.
CURRENT_MODEL_PROVIDER = None  # Provider used for the last analysis.
CURRENT_MODEL_NAME = None      # Model name used for the last analysis.

# Initialize OCR reader (lazy loading — torch dependency is heavy).
OCR_READER = None

# Rate limiting for API calls
from collections import deque
import time as time_module
class APIRateLimiter:
    """Global, process-wide rate limiter for outbound AI-provider API calls.

    Enforces two constraints per provider: a minimum interval between
    consecutive requests, and a requests-per-minute (RPM) cap. Callers
    invoke :meth:`wait_if_needed` immediately before each API call; the
    method blocks (sleeps) until the call is safe, then records it.
    """

    # Per-provider limits, set slightly below each provider's quota to
    # leave a safety buffer (e.g. 19 RPM against a 20 RPM quota).
    LIMITS = {
        "gemini": {"rpm": 19, "min_interval": 3.5},
        "groq": {"rpm": 29, "min_interval": 2.1},
        "mistral": {"rpm": 59, "min_interval": 1.0},
        "openrouter": {"rpm": 50, "min_interval": 1.2},
    }

    def __init__(self):
        # Ring buffers of recent request timestamps, one per provider.
        # maxlen equals the provider's RPM cap, so the oldest entry is
        # exactly the request that bounds the 60-second window.
        self.request_times = {
            "gemini": deque(maxlen=19),
            "groq": deque(maxlen=29),
            "mistral": deque(maxlen=59),
            "openrouter": deque(maxlen=50),
        }

    def wait_if_needed(self, provider: str) -> None:
        """Sleep as needed before an API call to *provider*, then record it.

        Unknown providers are a no-op (nothing is recorded).
        """
        provider = provider.lower()
        if provider not in self.LIMITS:
            return
        limits = self.LIMITS[provider]
        request_times = self.request_times[provider]
        now = time_module.time()
        # Enforce the minimum interval between consecutive requests.
        if request_times:
            time_since_last = now - request_times[-1]
            if time_since_last < limits["min_interval"]:
                wait_time = limits["min_interval"] - time_since_last
                print(f"[Rate Limiter] Waiting {wait_time:.2f}s before {provider} API call...")
                time_module.sleep(wait_time)
                # BUGFIX: refresh the clock after sleeping so the RPM-window
                # check below uses the actual current time instead of a stale
                # timestamp (which caused needlessly long waits).
                now = time_module.time()
        # Enforce the per-minute cap: if the oldest recorded request is
        # still inside the 60s window, wait until it slides out.
        if len(request_times) >= limits["rpm"]:
            oldest = request_times[0]
            elapsed = now - oldest
            if elapsed < 60:
                wait_time = 60 - elapsed + 1.0  # Add 1s buffer
                print(f"[Rate Limiter] Hit {limits['rpm']} RPM limit for {provider}. Waiting {wait_time:.2f}s...")
                time_module.sleep(wait_time)
        # Record this request.
        request_times.append(time_module.time())


# Global rate limiter instance shared by all request handlers.
RATE_LIMITER = APIRateLimiter()
| # Types | |
class TableData(BaseModel):
    """One extracted data table plus optional chart configuration."""
    title: str  # Category name, e.g. "Formation Tops".
    rows: List[Dict[str, Any]]  # Each row may carry a "__confidence" score.
    page_number: Optional[int] = 1  # 1-indexed source page.
    visualization_config: Optional[Dict[str, Any]] = None  # Chart spec or None.
class Metadata(BaseModel):
    """Document-level metadata returned with an analysis result."""
    fileName: str
    dateProcessed: str
    pageCount: int
    fileType: str
class AnalysisResult(BaseModel):
    """Full result of a document analysis: summary, metadata, and tables."""
    summary: str
    metadata: Metadata
    tables: List[TableData]
class ChatRequest(BaseModel):
    """Request payload for the chat endpoint."""
    message: str  # The user's current question.
    history: List[Dict[str, str]] = []  # Prior turns as {"role", "content"} dicts.
    model_name: Optional[str] = "gemini-3-flash-preview"
    model_provider: Optional[str] = "gemini"
    use_rlm: Optional[bool] = False  # Route through the RLM deep-research engine.
class ChatResponse(BaseModel):
    """Response payload for the chat endpoint.

    BUGFIX: this class was previously defined twice; the first (answer-only)
    definition was dead code immediately shadowed by this one, so it has
    been removed. Effective behavior is unchanged.
    """
    answer: str  # The model's reply text.
    citations: List[int] = []  # Page numbers cited in the answer, if parsed.
    reasoning_trace: Optional[List[Dict[str, Any]]] = None  # RLM trace, if used.
def clean_json_string(s: str) -> str:
    """Strip markdown fences and extract the JSON payload from *s*.

    Models frequently wrap JSON in ```json fences or surround it with
    prose. This removes the fences and, if the result still doesn't look
    like JSON, falls back to grabbing the outermost {...} span from the
    raw string.

    BUGFIX: payloads that are top-level JSON arrays ('[...]') are now
    returned intact instead of being mangled by the {...} fallback —
    the caller explicitly normalizes list-shaped responses.
    """
    # First, try to just strip the markdown fence markers.
    cleaned = re.sub(r'```json\n?|\n?```', '', s).strip()
    # Accept anything that already starts like a JSON object or array.
    if not cleaned.startswith(('{', '[')):
        # Otherwise hunt for an embedded object in the raw string.
        match = re.search(r'\{.*\}', s, re.DOTALL)
        if match:
            return match.group(0)
    return cleaned
def is_scanned_pdf(pdf_bytes: bytes) -> bool:
    """Heuristically detect a scanned (image-only) PDF.

    Samples up to the first three pages; if they yield almost no
    extractable text, the document is assumed to be scanned. Returns
    False on any parsing error (treat as a regular PDF).
    """
    try:
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        sample = [doc[i].get_text() for i in range(min(3, len(doc)))]
        doc.close()
        # Under 100 characters across the sampled pages => likely scanned.
        return len("".join(sample).strip()) < 100
    except Exception as e:
        print(f"Error checking if PDF is scanned: {e}")
        return False
def extract_text_with_ocr(pdf_bytes: bytes) -> str:
    """Extract text from a scanned PDF using EasyOCR (if available).

    EasyOCR is imported lazily because its torch dependency is heavy and
    may be absent; on import/initialization failure — or any OCR error —
    this falls back to plain text extraction via extract_text_from_pdf.

    Returns page texts joined with "--- Page N ---" separators.
    """
    global OCR_READER
    # Lazily create the (expensive) OCR reader exactly once per process.
    try:
        if OCR_READER is None:
            print("Attempting to load EasyOCR...")
            import easyocr
            print("Initializing EasyOCR reader...")
            OCR_READER = easyocr.Reader(['en'])
            print("✓ EasyOCR initialized successfully")
    except (ImportError, OSError) as e:
        print(f"⚠ EasyOCR not available: {e}")
        print(" Falling back to basic text extraction")
        return extract_text_from_pdf(pdf_bytes)
    try:
        pdf_doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        extracted_text = []
        for page_num in range(len(pdf_doc)):
            page = pdf_doc[page_num]
            # Render the page at 2x zoom for better OCR accuracy.
            pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
            img_bytes = pix.tobytes("png")
            # BUGFIX: removed an unused PIL Image.open() of the same bytes —
            # EasyOCR consumes the PNG bytes directly, so the extra decode
            # per page was pure waste.
            result = OCR_READER.readtext(img_bytes, detail=0)
            page_text = "\n".join(result)
            extracted_text.append(f"--- Page {page_num + 1} ---\n{page_text}")
        pdf_doc.close()
        return "\n\n".join(extracted_text)
    except Exception as e:
        print(f"OCR Error: {e}")
        # Fall back to regular extraction if OCR fails.
        print("Falling back to regular text extraction...")
        return extract_text_from_pdf(pdf_bytes)
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Extract embedded text from a (non-scanned) PDF.

    Returns page texts joined with "--- Page N ---" separators. Raises
    HTTPException(500) if PyMuPDF cannot parse the document.
    """
    try:
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        pages = [
            f"--- Page {idx + 1} ---\n{doc[idx].get_text()}"
            for idx in range(len(doc))
        ]
        doc.close()
        return "\n\n".join(pages)
    except Exception as e:
        print(f"Text extraction error: {e}")
        raise HTTPException(status_code=500, detail=f"Text extraction failed: {str(e)}")
def analyze_with_gemini(contents: bytes, mime_type: str, model_name: str, system_instruction: str, prompt: str) -> str:
    """Analyze document using Gemini models.

    Large documents (>32k tokens) are uploaded once into a Gemini context
    cache (60-minute TTL) so later chat calls can reuse them cheaply;
    smaller documents — or a failed cache attempt — go through the
    standard inline-upload flow. Returns the raw response text, which is
    requested as JSON via response_mime_type.
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise HTTPException(status_code=500, detail="GEMINI_API_KEY is missing")
    genai.configure(api_key=api_key)
    # Rate limit before token count API call
    RATE_LIMITER.wait_if_needed("gemini")
    # Check token count to decide on caching
    count_model = genai.GenerativeModel(f"models/{model_name}")
    token_count = count_model.count_tokens([{'mime_type': mime_type, 'data': contents}])
    print(f"Token count: {token_count.total_tokens}")
    # 32k threshold for caching
    if token_count.total_tokens > 32768:
        print("Document > 32k tokens. Using Context Caching.")
        global CURRENT_CACHE_NAME
        # Clean up old cache if exists (one cached document per session).
        if CURRENT_CACHE_NAME:
            try:
                caching.CachedContent.delete(CURRENT_CACHE_NAME)
                print(f"Deleted old cache: {CURRENT_CACHE_NAME}")
            except Exception as e:
                # Best-effort cleanup; a stale cache simply expires via TTL.
                print(f"Error deleting old cache: {e}")
        # Create new cache holding the document plus system instruction.
        try:
            # Rate limit before cache creation API call
            RATE_LIMITER.wait_if_needed("gemini")
            cache = caching.CachedContent.create(
                model=f'models/{model_name}',
                display_name='petromind_doc_cache',
                system_instruction=system_instruction,
                contents=[{'mime_type': mime_type, 'data': contents}],
                ttl=datetime.timedelta(minutes=60)
            )
            CURRENT_CACHE_NAME = cache.name
            print(f"Created new cache: {CURRENT_CACHE_NAME}")
            # Create model bound to the cached content.
            model = genai.GenerativeModel.from_cached_content(cached_content=cache)
            # Rate limit before content generation API call
            RATE_LIMITER.wait_if_needed("gemini")
            response = model.generate_content(
                contents=[prompt],
                generation_config=genai.types.GenerationConfig(
                    response_mime_type="application/json"
                ),
                request_options={'timeout': 600}
            )
            return response.text
        except Exception as e:
            # Any caching failure falls through to the standard flow below.
            print(f"Caching failed, falling back to standard upload: {e}")
            CURRENT_CACHE_NAME = None
    # Standard flow (no cache or cache failed): inline document upload.
    print("Using standard Gemini flow")
    CURRENT_CACHE_NAME = None
    model = genai.GenerativeModel(
        model_name=f"models/{model_name}",
        system_instruction=system_instruction
    )
    # Rate limit before content generation API call
    RATE_LIMITER.wait_if_needed("gemini")
    response = model.generate_content(
        contents=[
            {'mime_type': mime_type, 'data': contents},
            prompt
        ],
        generation_config=genai.types.GenerationConfig(
            response_mime_type="application/json"
        ),
        request_options={'timeout': 600}
    )
    return response.text
def analyze_with_groq(text_content: str, model_name: str, system_instruction: str, prompt: str) -> str:
    """Run a one-shot JSON-mode extraction request against a Groq model.

    The system instruction, task prompt, and document text are folded into
    a single user message. Raises HTTPException(500) when the API key is
    absent/unconfigured or the API call fails.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key or api_key == "your_groq_api_key_here":
        raise HTTPException(status_code=500, detail="GROQ_API_KEY is missing or not configured")
    combined = f"{system_instruction}\n\n{prompt}\n\nDocument Content:\n{text_content}"
    groq_client = Groq(api_key=api_key)
    try:
        reply = groq_client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": combined}],
            temperature=0.7,
            max_tokens=8000,
            response_format={"type": "json_object"},
        )
        return reply.choices[0].message.content
    except Exception as e:
        print(f"Groq API error: {e}")
        raise HTTPException(status_code=500, detail=f"Groq API error: {str(e)}")
def analyze_with_mistral(text_content: str, model_name: str, system_instruction: str, prompt: str) -> str:
    """Run a one-shot JSON-mode extraction request against a Mistral model.

    The system instruction, task prompt, and document text are folded into
    a single user message. Raises HTTPException(500) when the API key is
    absent/unconfigured or the API call fails.
    """
    api_key = os.getenv("MISTRAL_API_KEY")
    if not api_key or api_key == "your_mistral_api_key_here":
        raise HTTPException(status_code=500, detail="MISTRAL_API_KEY is missing or not configured")
    combined = f"{system_instruction}\n\n{prompt}\n\nDocument Content:\n{text_content}"
    mistral_client = Mistral(api_key=api_key)
    try:
        reply = mistral_client.chat.complete(
            model=model_name,
            messages=[{"role": "user", "content": combined}],
            response_format={"type": "json_object"},
        )
        return reply.choices[0].message.content
    except Exception as e:
        print(f"Mistral API error: {e}")
        raise HTTPException(status_code=500, detail=f"Mistral API error: {str(e)}")
def analyze_with_openrouter(text_content: str, model_name: str, system_instruction: str, prompt: str) -> str:
    """Run a one-shot extraction request against an OpenRouter-hosted model.

    Uses the OpenAI client pointed at OpenRouter's API. The system
    instruction, task prompt, and document text are folded into a single
    user message. Raises HTTPException(500) on a missing key or API error.
    """
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise HTTPException(status_code=500, detail="OPENROUTER_API_KEY is missing")
    combined = f"{system_instruction}\n\n{prompt}\n\nDocument Content:\n{text_content}"
    or_client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
    )
    try:
        reply = or_client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": combined}],
        )
        return reply.choices[0].message.content
    except Exception as e:
        print(f"OpenRouter API error: {e}")
        raise HTTPException(status_code=500, detail=f"OpenRouter API error: {str(e)}")
def _extract_text_auto(contents: bytes) -> str:
    """Extract text from PDF bytes, OCR-ing when the PDF looks scanned.

    Shared helper for the text-only providers (Groq, Mistral, OpenRouter),
    which cannot consume PDF bytes directly.
    """
    if is_scanned_pdf(contents):
        print("Scanned PDF detected. Using EasyOCR...")
        return extract_text_with_ocr(contents)
    print("Regular PDF detected. Extracting text...")
    return extract_text_from_pdf(contents)


# NOTE(review): no @app.post(...) decorator is visible on this handler in
# this copy of the file — confirm the route registration wasn't lost.
async def analyze_document(
    file: UploadFile = File(...),
    model_name: str = Form("gemini-3-flash-preview"),
    model_provider: str = Form("gemini"),
    use_rlm: str = Form("false")  # Receive as string "true"/"false" from FormData
):
    """Analyze an uploaded O&G document and return structured tables.

    Stores the upload in module-level session state for later chat calls.
    When the user enables RLM mode, delegates to RLMExtractionEngine;
    otherwise prompts the selected provider for a JSON result and
    normalizes its shape so the frontend always receives a dict with a
    "tables" list. Raises HTTPException 413 (too large), 400 (unknown
    provider), or 500 (processing failure).
    """
    global CURRENT_FILE_CONTENT, CURRENT_MIME_TYPE, CURRENT_CACHE_NAME
    global CURRENT_MODEL_PROVIDER, CURRENT_MODEL_NAME
    try:
        contents = await file.read()
        # Enforce the upload size cap manually (no middleware does this).
        if len(contents) > MAX_UPLOAD_SIZE:
            raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE / (1024*1024):.0f}MB")
        # Persist session state for the chat endpoint.
        CURRENT_FILE_CONTENT = contents
        CURRENT_MIME_TYPE = file.content_type or "application/pdf"
        CURRENT_MODEL_PROVIDER = model_provider
        CURRENT_MODEL_NAME = model_name
        # Parse boolean flag sent as a form string.
        should_use_rlm = use_rlm.lower() == "true"
        # Count pages for logging/metadata.
        page_count = 1
        if CURRENT_MIME_TYPE == "application/pdf":
            try:
                pdf_doc = fitz.open(stream=contents, filetype="pdf")
                page_count = len(pdf_doc)
                pdf_doc.close()
                print(f"[Analyze] Document has {page_count} pages")
            except Exception as e:
                print(f"[Analyze] Could not count pages: {e}")
        # Use RLM if the user requested it.
        if should_use_rlm:
            print(f"[Analyze] RLM Mode ENABLED by user. Starting RLM extraction for {page_count} pages...")
            # Extract text for RLM.
            if is_scanned_pdf(contents):
                print("[Analyze] Scanned PDF detected. Using OCR...")
                text_content = extract_text_with_ocr(contents)
            else:
                text_content = extract_text_from_pdf(contents)
            # Run RLM extraction.
            rlm_engine = RLMExtractionEngine(
                doc_content=text_content,
                total_pages=page_count,
                model_provider=model_provider,
                model_name=model_name
            )
            result = rlm_engine.run()
            # Overwrite metadata with trusted local values.
            result["metadata"]["fileName"] = file.filename or "Unknown"
            result["metadata"]["pageCount"] = page_count
            return result
        # Standard extraction path: enhanced O&G system instruction.
        system_instruction = """
You are an expert Senior Data Engineer specializing in Oil & Gas and Petroleum industry data extraction.
Your task is to extract structured data from technical documents (drilling reports, production logs, well headers, geological surveys, completion reports, etc.).
DOCUMENT ANALYSIS APPROACH:
1. First, identify what TYPE of O&G document this is (Daily Drilling Report, Well Completion, Production Report, etc.)
2. Based on document type, determine which data categories are RELEVANT
3. Extract ONLY categories that contain actual data in the document
4. Be intelligent - don't force extraction of categories that don't exist
CRITICAL INSTRUCTIONS:
1. You MUST output data into separate, logical tables
2. You MUST provide a Confidence Score (0.0 to 1.0) for every row
3. ONLY include tables for data categories that actually exist in the document
4. Detect and preserve units (ft, m, psi, bar, bbl, etc.)
O&G DATA CATEGORIES TO LOOK FOR:
- Well Header/Identification (Well Name, Operator, Field, Basin, API#, UWI, Coordinates, Spud/TD Dates)
- Casings & Tubulars (Casing Size, Weight, Grade, Depth Set, Cement Details)
- Deviation/Directional Survey (MD, TVD, Inclination, Azimuth, VS, DLS)
- Formation Tops/Geological Markers (Formation Name, Top Depth, Base Depth, Thickness)
- Mud/Drilling Fluids (Mud Weight, Viscosity, PV, YP, Gel Strength, pH, Cl)
- Production Data (Oil Rate, Gas Rate, Water Rate, GOR, WOR, Choke, Tubing/Casing Pressure)
- Well Tests (DST, LOT, FIT, Flow Rates, Pressures, Permeability, Skin)
- Core/Log Analysis (Porosity, Permeability, Saturation, Net Pay)
- BHA/Equipment (Tool Description, Length, OD, ID, Serial Numbers)
- Daily Operations/Activities (Time, Depth, Activity, Remarks)
"""
        prompt = """
Please analyze and extract structured data from this Oil & Gas document.
Structure the response as JSON:
{
"summary": "Brief executive summary describing document type and key findings",
"metadata": {
"fileName": "derived from document or context",
"dateProcessed": "today's date",
"pageCount": 0,
"fileType": "PDF/Image"
},
"tables": [
{
"title": "Category Name",
"rows": [
{ "Column1": "value", "Column2": 123, "__confidence": 0.95 }
],
"page_number": 1,
"visualization_config": null or { "type": "line_chart/bar_chart/scatter_chart", "xAxisKey": "...", "yAxisKeys": [...], "title": "..." }
}
]
}
EXTRACTION RULES:
1. Create a separate table for each data category FOUND in the document
2. Do NOT create empty tables or tables for categories not in the document
3. Use consistent column naming (PascalCase or snake_case with units: Depth_ft, Pressure_psi)
4. Include "__confidence" (0.0-1.0) in EVERY row based on data clarity
5. Include "page_number" for each table (1-indexed)
6. For time-series or correlation data, add "visualization_config" with:
- "type": "line_chart" (time-series), "bar_chart" (comparison), "scatter_chart" (correlation)
- "xAxisKey": column for X axis
- "yAxisKeys": columns for Y axis
- "title": descriptive chart title
7. For non-visualizable data, set "visualization_config" to null
"""
        response_text = None
        # Route to the appropriate provider. Gemini consumes the PDF bytes
        # directly; the other providers need the text extracted first.
        if model_provider == "gemini":
            response_text = analyze_with_gemini(contents, CURRENT_MIME_TYPE, model_name, system_instruction, prompt)
        elif model_provider == "groq":
            response_text = analyze_with_groq(_extract_text_auto(contents), model_name, system_instruction, prompt)
        elif model_provider == "mistral":
            response_text = analyze_with_mistral(_extract_text_auto(contents), model_name, system_instruction, prompt)
        elif model_provider == "openrouter":
            response_text = analyze_with_openrouter(_extract_text_auto(contents), model_name, system_instruction, prompt)
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported model provider: {model_provider}")
        if not response_text:
            raise HTTPException(status_code=500, detail="No response text received from AI model.")
        cleaned_text = clean_json_string(response_text)
        try:
            parsed = json.loads(cleaned_text)
        except json.JSONDecodeError as e:
            print(f"[Analyze] JSON Parse Error: {e}")
            print(f"[Analyze] Failed Text Check: {cleaned_text[:500]}...")
            # Return a fallback valid response so the UI doesn't crash.
            return {
                "summary": f"**Analysis Error**: The AI model generated invalid JSON data.\n\nError details: {str(e)}\n\nThis often happens with smaller or free models. Please try again or switch models.",
                "metadata": {
                    "fileName": file.filename or "Unknown",
                    "dateProcessed": "Error",
                    "pageCount": 0,
                    "fileType": CURRENT_MIME_TYPE
                },
                "tables": []
            }
        # Normalize list-shaped responses into the expected dict shape.
        if isinstance(parsed, list):
            if len(parsed) == 1 and isinstance(parsed[0], dict) and "tables" in parsed[0]:
                parsed = parsed[0]
            elif parsed and isinstance(parsed[0], dict) and "rows" in parsed[0]:
                parsed = {"tables": parsed}
            else:
                parsed = {"tables": [{"title": "Extracted Data", "rows": parsed}]}
        if "tables" not in parsed:
            if "data" in parsed and isinstance(parsed["data"], list):
                parsed["tables"] = [{"title": "Extracted Data", "rows": parsed["data"]}]
            else:
                parsed["tables"] = []
        # Tell the frontend that standard (non-RLM) extraction was used.
        parsed["using_rlm"] = False
        return parsed
    except HTTPException:
        # BUGFIX: let deliberate HTTP errors (e.g. the 413 size check above)
        # propagate with their original status code instead of being
        # re-wrapped as generic 500s by the handler below.
        raise
    except Exception as e:
        print(f"Error analyzing document: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): no @app.post(...) decorator is visible on this handler in
# this copy of the file — confirm the route registration wasn't lost.
async def chat_document(request: ChatRequest) -> ChatResponse:
    """Answer a question about the document loaded by analyze_document.

    Routes to the optional RLM deep-research engine, or to the selected
    provider: Gemini (reusing the context cache created during analysis
    when possible), Groq, Mistral, or OpenRouter. Raises HTTPException
    400 when no document is loaded, 500 on any provider failure.
    """
    print(f"DEBUG: Chat Request received. use_rlm={request.use_rlm}, model={request.model_name}")
    global CURRENT_FILE_CONTENT, CURRENT_MIME_TYPE, CURRENT_CACHE_NAME
    global CURRENT_MODEL_PROVIDER, CURRENT_MODEL_NAME
    if not CURRENT_FILE_CONTENT:
        raise HTTPException(status_code=400, detail="No document loaded. Please upload a document first.")
    try:
        model_provider = request.model_provider
        model_name = request.model_name
        # Shared grounding instructions used by every provider path.
        system_instruction = """
You are a helpful assistant answering questions about the provided Oil & Gas document.
1. Answer strictly based on the document content. If the answer is not in the document, say so.
2. Where possible, cite the page number where the information is found using the format [Page X].
Example: "The well depth is 1250m [Page 3]."
3. Be concise and accurate.
4. Do not hallucinate.
"""
        # RLM "deep research" path: re-extract text and delegate entirely.
        if request.use_rlm:
            print("Using RLM for deep research...")
            # Get text content for RLM (OCR when the PDF looks scanned).
            is_scanned = is_scanned_pdf(CURRENT_FILE_CONTENT)
            if is_scanned:
                text_content = extract_text_with_ocr(CURRENT_FILE_CONTENT)
            else:
                text_content = extract_text_from_pdf(CURRENT_FILE_CONTENT)
            rlm_engine = RLMEngine(
                doc_content=text_content,
                model_provider=model_provider,
                model_name=model_name
            )
            rlm_result = rlm_engine.run(request.message)
            return ChatResponse(
                answer=rlm_result["answer"],
                citations=[],
                reasoning_trace=rlm_result["reasoning_trace"]
            )
        if model_provider == "gemini":
            # Prefer the context cache created during analysis, if available.
            if CURRENT_CACHE_NAME and CURRENT_MODEL_PROVIDER == "gemini":
                try:
                    print(f"Using existing cache: {CURRENT_CACHE_NAME}")
                    # NOTE(review): assumes genai.configure() already ran during
                    # analysis — verify chat cannot be reached with a cache name
                    # but an unconfigured client.
                    cache = caching.CachedContent.get(CURRENT_CACHE_NAME)
                    model = genai.GenerativeModel.from_cached_content(cached_content=cache)
                    # Convert history to Gemini's role/parts format.
                    gemini_history = []
                    for msg in request.history:
                        role = "user" if msg.get("role") == "user" else "model"
                        gemini_history.append({
                            "role": role,
                            "parts": [msg.get("content", "")]
                        })
                    chat_session = model.start_chat(history=gemini_history)
                    # Rate limit before chat API call
                    RATE_LIMITER.wait_if_needed("gemini")
                    response = chat_session.send_message(request.message, request_options={'timeout': 600})
                    return ChatResponse(
                        answer=response.text,
                        citations=[]
                    )
                except Exception as e:
                    # Cache expired or invalid: drop it and use the standard flow.
                    print(f"Error using cache in chat, falling back: {e}")
                    CURRENT_CACHE_NAME = None
            # Standard Gemini flow: re-send the whole document with the question.
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise HTTPException(status_code=500, detail="GEMINI_API_KEY is missing")
            genai.configure(api_key=api_key)
            model = genai.GenerativeModel(
                model_name=f"models/{model_name}",
                system_instruction=system_instruction
            )
            # Convert history to Gemini's role/parts format.
            gemini_history = []
            for msg in request.history:
                role = "user" if msg.get("role") == "user" else "model"
                gemini_history.append({
                    "role": role,
                    "parts": [msg.get("content", "")]
                })
            chat_session = model.start_chat(history=gemini_history)
            # Rate limit before chat API call
            RATE_LIMITER.wait_if_needed("gemini")
            response = chat_session.send_message([
                {'mime_type': CURRENT_MIME_TYPE, 'data': CURRENT_FILE_CONTENT},
                request.message
            ], request_options={'timeout': 600})
            return ChatResponse(
                answer=response.text,
                citations=[]
            )
        elif model_provider == "groq":
            # Text-only provider: extract document text first.
            is_scanned = is_scanned_pdf(CURRENT_FILE_CONTENT)
            if is_scanned:
                text_content = extract_text_with_ocr(CURRENT_FILE_CONTENT)
            else:
                text_content = extract_text_from_pdf(CURRENT_FILE_CONTENT)
            api_key = os.getenv("GROQ_API_KEY")
            if not api_key or api_key == "your_groq_api_key_here":
                raise HTTPException(status_code=500, detail="GROQ_API_KEY is missing or not configured")
            client = Groq(api_key=api_key)
            # System message carries the instructions plus full document text.
            messages = [
                {
                    "role": "system",
                    "content": f"{system_instruction}\n\nDocument Content:\n{text_content}"
                }
            ]
            # Add conversation history
            for msg in request.history:
                role = msg.get("role", "user")
                messages.append({
                    "role": role,
                    "content": msg.get("content", "")
                })
            # Add current user message
            messages.append({
                "role": "user",
                "content": request.message
            })
            # NOTE(review): unlike the Gemini path, no RATE_LIMITER call is made
            # here — confirm whether Groq chat should also be rate limited.
            completion = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0.7,
                max_tokens=4096
            )
            return ChatResponse(
                answer=completion.choices[0].message.content,
                citations=[]
            )
        elif model_provider == "mistral":
            # Text-only provider: extract document text first.
            is_scanned = is_scanned_pdf(CURRENT_FILE_CONTENT)
            if is_scanned:
                text_content = extract_text_with_ocr(CURRENT_FILE_CONTENT)
            else:
                text_content = extract_text_from_pdf(CURRENT_FILE_CONTENT)
            api_key = os.getenv("MISTRAL_API_KEY")
            if not api_key or api_key == "your_mistral_api_key_here":
                raise HTTPException(status_code=500, detail="MISTRAL_API_KEY is missing or not configured")
            client = Mistral(api_key=api_key)
            # System message carries the instructions plus full document text.
            messages = [
                {
                    "role": "system",
                    "content": f"{system_instruction}\n\nDocument Content:\n{text_content}"
                }
            ]
            # Add conversation history
            for msg in request.history:
                role = msg.get("role", "user")
                messages.append({
                    "role": role,
                    "content": msg.get("content", "")
                })
            # Add current user message
            messages.append({
                "role": "user",
                "content": request.message
            })
            chat_response = client.chat.complete(
                model=model_name,
                messages=messages,
                temperature=0.7,
                max_tokens=4096
            )
            return ChatResponse(
                answer=chat_response.choices[0].message.content,
                citations=[]
            )
        elif model_provider == "openrouter":
            # Text-only provider: extract document text first.
            is_scanned = is_scanned_pdf(CURRENT_FILE_CONTENT)
            if is_scanned:
                text_content = extract_text_with_ocr(CURRENT_FILE_CONTENT)
            else:
                text_content = extract_text_from_pdf(CURRENT_FILE_CONTENT)
            api_key = os.getenv("OPENROUTER_API_KEY")
            if not api_key:
                raise HTTPException(status_code=500, detail="OPENROUTER_API_KEY is missing")
            client = OpenAI(
                base_url="https://openrouter.ai/api/v1",
                api_key=api_key,
            )
            # System message carries the instructions plus full document text.
            messages = [
                {
                    "role": "system",
                    "content": f"{system_instruction}\n\nDocument Content:\n{text_content}"
                }
            ]
            # Add conversation history
            for msg in request.history:
                role = msg.get("role", "user")
                messages.append({
                    "role": role,
                    "content": msg.get("content", "")
                })
            # Add current user message
            messages.append({
                "role": "user",
                "content": request.message
            })
            completion = client.chat.completions.create(
                model=model_name,
                messages=messages,
            )
            return ChatResponse(
                answer=completion.choices[0].message.content,
                citations=[]
            )
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported model provider: {model_provider}")
    except Exception as e:
        # NOTE(review): this also catches HTTPExceptions raised above and
        # re-wraps them as 500s — consider re-raising HTTPException unchanged.
        print(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Mount frontend static files (only when a built frontend is present).
if os.path.exists(FRONTEND_DIST_PATH):
    # Mount assets (js, css, etc.)
    app.mount("/assets", StaticFiles(directory=os.path.join(FRONTEND_DIST_PATH, "assets")), name="assets")

    # Catch-all route for the single-page app.
    # NOTE(review): no @app.get("/{full_path:path}") decorator is visible
    # here — confirm the route registration wasn't lost from this copy;
    # without it this handler is never invoked.
    async def serve_spa(full_path: str):
        # API routes are already handled above because they are defined first
        # Serve real files (e.g. favicon.ico, logo.png) directly when present.
        possible_file = os.path.join(FRONTEND_DIST_PATH, full_path)
        if os.path.isfile(possible_file):
            return FileResponse(possible_file)
        # Otherwise serve index.html so client-side routing can take over.
        return FileResponse(os.path.join(FRONTEND_DIST_PATH, "index.html"))
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |