# PetroMind_AI — backend/main.py
# (file originally published by user "gauthamnairy" in commit 609c821, "Upload 41 files")
import os
import json
import re
import datetime
import io
from typing import Optional, List, Dict, Any, Tuple
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
import google.generativeai as genai
from google.generativeai import caching
from dotenv import load_dotenv
import base64
from groq import Groq
from groq import Groq
from mistralai import Mistral
from openai import OpenAI
import fitz # PyMuPDF
from PIL import Image
from rlm import RLMEngine
from rlm_extraction import RLMExtractionEngine
# EasyOCR will be imported lazily when needed to avoid startup issues with torch dependencies
load_dotenv()
app = FastAPI()
# Set up CORS — wide open for development; tighten allow_origins in prod.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For development, allow all. stricter in prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    max_age=3600,  # let browsers cache preflight responses for an hour
)
# Deployment Helpers - Path to frontend static files (built SPA bundle)
FRONTEND_DIST_PATH = os.path.join(os.path.dirname(__file__), "static")
# Upload size cap, enforced manually in /analyze after reading the body.
# NOTE(review): TrustedHostMiddleware is imported but never registered —
# presumably left over from an earlier body-size-limit attempt; confirm before removing.
from fastapi.middleware.trustedhost import TrustedHostMiddleware
MAX_UPLOAD_SIZE = 50 * 1024 * 1024  # 50MB
# RLM Extraction Thresholds
# Documents exceeding these thresholds will use RLM for extraction
RLM_PAGE_THRESHOLD = 300  # Use RLM for documents with more than 300 pages
RLM_TOKEN_THRESHOLD = 200000  # Or more than 200k estimated tokens
# Global storage for the current session. Module-level single-document state:
# each /analyze overwrites these, so concurrent users share/clobber one slot.
CURRENT_FILE_CONTENT = None   # raw uploaded bytes of the active document
CURRENT_MIME_TYPE = None      # MIME type of the active document
CURRENT_CACHE_NAME = None     # Gemini context-cache resource name, if one exists
CURRENT_MODEL_PROVIDER = None # provider selected at analyze time
CURRENT_MODEL_NAME = None     # model selected at analyze time
# OCR reader handle — lazily created on first OCR use (see extract_text_with_ocr)
OCR_READER = None
# Rate limiting for API calls
from collections import deque
import time as time_module
class APIRateLimiter:
    """Global rate limiter for outbound LLM API calls.

    Enforces two constraints per provider: a minimum spacing between
    consecutive requests, and a rolling requests-per-minute cap tracked
    via a bounded deque of recent request timestamps.
    """

    # Per-provider caps: rolling requests/minute + minimum seconds between calls.
    LIMITS = {
        "gemini": {"rpm": 19, "min_interval": 3.5},  # ~17 requests/min with buffer
        "groq": {"rpm": 29, "min_interval": 2.1},
        "mistral": {"rpm": 59, "min_interval": 1.0},
        "openrouter": {"rpm": 50, "min_interval": 1.2},
    }

    def __init__(self):
        # One timestamp deque per provider; maxlen equals the RPM cap so the
        # oldest entry is always the start of the current rolling window.
        self.request_times = {
            name: deque(maxlen=cfg["rpm"]) for name, cfg in self.LIMITS.items()
        }

    def wait_if_needed(self, provider: str):
        """Sleep just long enough that the next call to *provider* stays within limits."""
        provider = provider.lower()
        cfg = self.LIMITS.get(provider)
        if cfg is None:
            # Unknown provider: no throttling and no bookkeeping.
            return
        history = self.request_times[provider]
        now = time_module.time()

        # Constraint 1: minimum spacing since the previous request.
        if history:
            gap = now - history[-1]
            if gap < cfg["min_interval"]:
                wait_time = cfg["min_interval"] - gap
                print(f"[Rate Limiter] Waiting {wait_time:.2f}s before {provider} API call...")
                time_module.sleep(wait_time)

        # Constraint 2: rolling per-minute cap.
        if len(history) >= cfg["rpm"]:
            elapsed = now - history[0]
            if elapsed < 60:
                wait_time = 60 - elapsed + 1.0  # 1s safety buffer
                print(f"[Rate Limiter] Hit {cfg['rpm']} RPM limit for {provider}. Waiting {wait_time:.2f}s...")
                time_module.sleep(wait_time)

        # Record this request's timestamp.
        history.append(time_module.time())
# Global rate limiter instance shared by all request handlers in this module.
RATE_LIMITER = APIRateLimiter()
# Types
class TableData(BaseModel):
    """One structured table extracted from a document."""
    title: str  # human-readable category name, e.g. "Formation Tops"
    rows: List[Dict[str, Any]]  # one dict per row; rows may carry a "__confidence" key
    page_number: Optional[int] = 1  # 1-indexed source page
    visualization_config: Optional[Dict[str, Any]] = None  # chart spec, or None when not plottable
class Metadata(BaseModel):
    """Document-level metadata returned alongside extraction results."""
    fileName: str
    dateProcessed: str
    pageCount: int
    fileType: str  # e.g. "PDF/Image" or a MIME type
class AnalysisResult(BaseModel):
    """Shape of a full /analyze payload (summary + metadata + tables).

    NOTE(review): not referenced elsewhere in this module's visible code —
    /analyze returns plain dicts. Kept as schema documentation; confirm
    whether it should be wired in as the endpoint's response_model.
    """
    summary: str
    metadata: Metadata
    tables: List[TableData]
class ChatRequest(BaseModel):
    """Request payload for /chat."""
    message: str  # current user question
    history: List[Dict[str, str]] = []  # prior turns as {"role": ..., "content": ...}
    model_name: Optional[str] = "gemini-3-flash-preview"
    model_provider: Optional[str] = "gemini"  # gemini | groq | mistral | openrouter
    use_rlm: Optional[bool] = False  # route the question through the RLM deep-research engine
class ChatResponse(BaseModel):
    """Response payload for /chat.

    Fix: ChatResponse was declared twice in a row; the first, narrower
    definition (answer only) was dead code immediately shadowed by this one
    and has been removed. All call sites already use this full shape.
    """
    answer: str
    citations: List[int] = []  # page numbers cited in the answer (currently always empty)
    reasoning_trace: Optional[List[Dict[str, Any]]] = None  # RLM step trace when use_rlm was set
def clean_json_string(s: str) -> str:
    """Strip markdown fences and extract the JSON payload from a model response.

    Fix: the original only recognized JSON *objects* — a fenced top-level
    array fell through to the `\\{.*\\}` search, which silently unwrapped the
    first object and dropped the array. Arrays are now accepted both after
    fence-stripping and in the free-text fallback search.

    Returns the extracted JSON text, or the fence-stripped input unchanged
    when no JSON-looking span is found (callers handle the parse error).
    """
    # First, drop ```json ... ``` fences if present.
    cleaned = re.sub(r'```json\n?|\n?```', '', s).strip()
    # If it already starts like a JSON object or array, trust it as-is.
    if cleaned.startswith('{') or cleaned.startswith('['):
        return cleaned
    # Otherwise hunt for the first object/array embedded in surrounding prose.
    match = re.search(r'\{.*\}|\[.*\]', s, re.DOTALL)
    if match:
        return match.group(0)
    return cleaned
def is_scanned_pdf(pdf_bytes: bytes) -> bool:
    """Heuristically decide whether a PDF is a scan (image-only, no text layer).

    Samples the text layer of up to the first 3 pages; fewer than 100
    non-whitespace characters is treated as "scanned". Any failure to open
    or read the PDF is reported and treated as not-scanned.
    """
    try:
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        sample = "".join(doc[i].get_text() for i in range(min(3, len(doc))))
        doc.close()
        return len(sample.strip()) < 100
    except Exception as e:
        print(f"Error checking if PDF is scanned: {e}")
        return False
def extract_text_with_ocr(pdf_bytes: bytes) -> str:
    """Extract text from a scanned PDF using EasyOCR (if available).

    Lazily imports easyocr and caches a module-level reader in OCR_READER on
    first use. Each page is rendered at 2x zoom to PNG bytes and OCR'd;
    page texts are joined with "--- Page N ---" headers. Falls back to plain
    text extraction when EasyOCR is unavailable or OCR fails.

    Fix: removed a dead local (`img = Image.open(io.BytesIO(img_bytes))`) that
    decoded every page through PIL but was never used — the OCR call consumes
    the raw PNG bytes directly.
    """
    global OCR_READER
    # Try to import easyocr only when needed (heavy torch dependency).
    try:
        if OCR_READER is None:
            print("Attempting to load EasyOCR...")
            import easyocr
            print("Initializing EasyOCR reader...")
            OCR_READER = easyocr.Reader(['en'])
            print("✓ EasyOCR initialized successfully")
    except (ImportError, OSError) as e:
        print(f"⚠ EasyOCR not available: {e}")
        print(" Falling back to basic text extraction")
        return extract_text_from_pdf(pdf_bytes)
    try:
        pdf_doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        extracted_text = []
        for page_num in range(len(pdf_doc)):
            page = pdf_doc[page_num]
            # Render page to PNG at 2x zoom for better OCR accuracy.
            pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
            img_bytes = pix.tobytes("png")
            # Perform OCR on the PNG bytes (detail=0 → plain text strings only).
            result = OCR_READER.readtext(img_bytes, detail=0)
            page_text = "\n".join(result)
            extracted_text.append(f"--- Page {page_num + 1} ---\n{page_text}")
        pdf_doc.close()
        return "\n\n".join(extracted_text)
    except Exception as e:
        print(f"OCR Error: {e}")
        # Fall back to regular extraction if OCR fails mid-document.
        print("Falling back to regular text extraction...")
        return extract_text_from_pdf(pdf_bytes)
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Extract the text layer of a (non-scanned) PDF, page by page.

    Pages are joined with "--- Page N ---" headers so downstream prompts and
    citations can refer to 1-indexed page numbers. Raises HTTPException(500)
    if the PDF cannot be opened or read.
    """
    try:
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        pages = [
            f"--- Page {i + 1} ---\n{doc[i].get_text()}"
            for i in range(len(doc))
        ]
        doc.close()
        return "\n\n".join(pages)
    except Exception as e:
        print(f"Text extraction error: {e}")
        raise HTTPException(status_code=500, detail=f"Text extraction failed: {str(e)}")
def analyze_with_gemini(contents: bytes, mime_type: str, model_name: str, system_instruction: str, prompt: str) -> str:
    """Analyze a document with a Gemini model, using context caching for large files.

    Documents over 32k tokens get a server-side context cache (name stored in
    the CURRENT_CACHE_NAME global so /chat can reuse it); smaller documents —
    or any caching failure — go through a plain one-shot upload. Returns the
    raw response text, which the prompt constrains to JSON.

    Raises HTTPException(500) when GEMINI_API_KEY is absent; other API errors
    propagate to the caller.
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise HTTPException(status_code=500, detail="GEMINI_API_KEY is missing")
    genai.configure(api_key=api_key)
    # Rate limit before the token-count API call
    RATE_LIMITER.wait_if_needed("gemini")
    # Count tokens first to decide whether caching is worthwhile
    count_model = genai.GenerativeModel(f"models/{model_name}")
    token_count = count_model.count_tokens([{'mime_type': mime_type, 'data': contents}])
    print(f"Token count: {token_count.total_tokens}")
    # 32k-token threshold for caching
    if token_count.total_tokens > 32768:
        print("Document > 32k tokens. Using Context Caching.")
        global CURRENT_CACHE_NAME
        # Delete the previous session's cache, if any (best effort)
        if CURRENT_CACHE_NAME:
            try:
                caching.CachedContent.delete(CURRENT_CACHE_NAME)
                print(f"Deleted old cache: {CURRENT_CACHE_NAME}")
            except Exception as e:
                print(f"Error deleting old cache: {e}")
        # Create a new cache holding the document + system instruction
        try:
            # Rate limit before the cache-creation API call
            RATE_LIMITER.wait_if_needed("gemini")
            cache = caching.CachedContent.create(
                model=f'models/{model_name}',
                display_name='petromind_doc_cache',
                system_instruction=system_instruction,
                contents=[{'mime_type': mime_type, 'data': contents}],
                ttl=datetime.timedelta(minutes=60)
            )
            CURRENT_CACHE_NAME = cache.name
            print(f"Created new cache: {CURRENT_CACHE_NAME}")
            # Build a model bound to the cached content
            model = genai.GenerativeModel.from_cached_content(cached_content=cache)
            # Rate limit before the generation API call
            RATE_LIMITER.wait_if_needed("gemini")
            response = model.generate_content(
                contents=[prompt],
                generation_config=genai.types.GenerationConfig(
                    response_mime_type="application/json"
                ),
                request_options={'timeout': 600}
            )
            return response.text
        except Exception as e:
            # Fall through to the standard flow below on any caching failure
            print(f"Caching failed, falling back to standard upload: {e}")
            CURRENT_CACHE_NAME = None
    # Standard flow (small document, or cache creation failed)
    print("Using standard Gemini flow")
    CURRENT_CACHE_NAME = None
    model = genai.GenerativeModel(
        model_name=f"models/{model_name}",
        system_instruction=system_instruction
    )
    # Rate limit before the generation API call
    RATE_LIMITER.wait_if_needed("gemini")
    response = model.generate_content(
        contents=[
            {'mime_type': mime_type, 'data': contents},
            prompt
        ],
        generation_config=genai.types.GenerationConfig(
            response_mime_type="application/json"
        ),
        request_options={'timeout': 600}
    )
    return response.text
def analyze_with_groq(text_content: str, model_name: str, system_instruction: str, prompt: str) -> str:
    """Analyze document text with a Groq-hosted model.

    Groq has no native PDF support, so the caller passes pre-extracted text.
    Instructions, task prompt, and document are packed into a single user
    turn; JSON output mode is requested. Returns the raw response string.
    Raises HTTPException(500) on a missing/placeholder key or any API error.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key or api_key == "your_groq_api_key_here":
        raise HTTPException(status_code=500, detail="GROQ_API_KEY is missing or not configured")
    groq_client = Groq(api_key=api_key)
    full_prompt = f"{system_instruction}\n\n{prompt}\n\nDocument Content:\n{text_content}"
    request_messages = [{"role": "user", "content": full_prompt}]
    try:
        completion = groq_client.chat.completions.create(
            model=model_name,
            messages=request_messages,
            temperature=0.7,
            max_tokens=8000,
            response_format={"type": "json_object"},
        )
        return completion.choices[0].message.content
    except Exception as e:
        print(f"Groq API error: {e}")
        raise HTTPException(status_code=500, detail=f"Groq API error: {str(e)}")
def analyze_with_mistral(text_content: str, model_name: str, system_instruction: str, prompt: str) -> str:
    """Analyze document text with a Mistral AI model.

    Mirrors analyze_with_groq: instructions + prompt + extracted document text
    in one user message, JSON output mode requested. Returns the raw response
    string. Raises HTTPException(500) on a missing/placeholder key or API error.
    """
    api_key = os.getenv("MISTRAL_API_KEY")
    if not api_key or api_key == "your_mistral_api_key_here":
        raise HTTPException(status_code=500, detail="MISTRAL_API_KEY is missing or not configured")
    mistral_client = Mistral(api_key=api_key)
    full_prompt = f"{system_instruction}\n\n{prompt}\n\nDocument Content:\n{text_content}"
    request_messages = [{"role": "user", "content": full_prompt}]
    try:
        chat_response = mistral_client.chat.complete(
            model=model_name,
            messages=request_messages,
            response_format={"type": "json_object"},
        )
        return chat_response.choices[0].message.content
    except Exception as e:
        print(f"Mistral API error: {e}")
        raise HTTPException(status_code=500, detail=f"Mistral API error: {str(e)}")
def analyze_with_openrouter(text_content: str, model_name: str, system_instruction: str, prompt: str) -> str:
    """Analyze document text via OpenRouter (OpenAI-compatible API).

    Same single-turn packing as the other text providers. No JSON response
    format is requested here, so downstream cleaning must tolerate prose.
    Returns the raw response string. Raises HTTPException(500) on a missing
    key or any API error.
    """
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise HTTPException(status_code=500, detail="OPENROUTER_API_KEY is missing")
    openrouter_client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
    )
    full_prompt = f"{system_instruction}\n\n{prompt}\n\nDocument Content:\n{text_content}"
    request_messages = [{"role": "user", "content": full_prompt}]
    try:
        completion = openrouter_client.chat.completions.create(
            model=model_name,
            messages=request_messages,
        )
        return completion.choices[0].message.content
    except Exception as e:
        print(f"OpenRouter API error: {e}")
        raise HTTPException(status_code=500, detail=f"OpenRouter API error: {str(e)}")
@app.post("/analyze")
async def analyze_document(
    file: UploadFile = File(...),
    model_name: str = Form("gemini-3-flash-preview"),
    model_provider: str = Form("gemini"),
    use_rlm: str = Form("false")  # Receive as string "true"/"false" from FormData
):
    """Upload and analyze a document, returning structured extraction JSON.

    Side effects: stores the file bytes / MIME / provider / model in the
    module-level CURRENT_* globals so /chat can reuse them (single active
    document per process). Routing:
    - use_rlm="true" → RLMExtractionEngine over extracted text (OCR if scanned)
    - gemini → raw bytes sent directly (see analyze_with_gemini)
    - groq / mistral / openrouter → text extracted first, then sent inline

    Returns a dict shaped like AnalysisResult plus a "using_rlm" flag; on
    unparseable model output returns a fallback payload with empty tables
    rather than failing the request.

    NOTE(review): the outer `except Exception` also catches HTTPExceptions
    raised inside (e.g. the 413 size check) and re-raises them as generic
    500s — confirm whether that is intended.
    """
    global CURRENT_FILE_CONTENT, CURRENT_MIME_TYPE, CURRENT_CACHE_NAME
    global CURRENT_MODEL_PROVIDER, CURRENT_MODEL_NAME
    try:
        contents = await file.read()
        # Enforce the upload cap manually (no body-size middleware is active)
        if len(contents) > MAX_UPLOAD_SIZE:
            raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE / (1024*1024):.0f}MB")
        # Persist session state for /chat
        CURRENT_FILE_CONTENT = contents
        CURRENT_MIME_TYPE = file.content_type or "application/pdf"
        CURRENT_MODEL_PROVIDER = model_provider
        CURRENT_MODEL_NAME = model_name
        # Parse the form-data boolean flag
        should_use_rlm = use_rlm.lower() == "true"
        # Count pages for logging/metadata (best effort)
        page_count = 1
        if CURRENT_MIME_TYPE == "application/pdf":
            try:
                pdf_doc = fitz.open(stream=contents, filetype="pdf")
                page_count = len(pdf_doc)
                pdf_doc.close()
                print(f"[Analyze] Document has {page_count} pages")
            except Exception as e:
                print(f"[Analyze] Could not count pages: {e}")
        # Use RLM if the user requested it
        if should_use_rlm:
            print(f"[Analyze] RLM Mode ENABLED by user. Starting RLM extraction for {page_count} pages...")
            # RLM works on plain text — OCR only when there is no text layer
            is_scanned = is_scanned_pdf(contents)
            if is_scanned:
                print("[Analyze] Scanned PDF detected. Using OCR...")
                text_content = extract_text_with_ocr(contents)
            else:
                text_content = extract_text_from_pdf(contents)
            # Run RLM extraction
            rlm_engine = RLMExtractionEngine(
                doc_content=text_content,
                total_pages=page_count,
                model_provider=model_provider,
                model_name=model_name
            )
            result = rlm_engine.run()
            # Overwrite metadata with authoritative values from the upload
            result["metadata"]["fileName"] = file.filename or "Unknown"
            result["metadata"]["pageCount"] = page_count
            return result
        # Standard (non-RLM) extraction path.
        # Enhanced O&G system instruction shared by all providers.
        system_instruction = """
You are an expert Senior Data Engineer specializing in Oil & Gas and Petroleum industry data extraction.
Your task is to extract structured data from technical documents (drilling reports, production logs, well headers, geological surveys, completion reports, etc.).
DOCUMENT ANALYSIS APPROACH:
1. First, identify what TYPE of O&G document this is (Daily Drilling Report, Well Completion, Production Report, etc.)
2. Based on document type, determine which data categories are RELEVANT
3. Extract ONLY categories that contain actual data in the document
4. Be intelligent - don't force extraction of categories that don't exist
CRITICAL INSTRUCTIONS:
1. You MUST output data into separate, logical tables
2. You MUST provide a Confidence Score (0.0 to 1.0) for every row
3. ONLY include tables for data categories that actually exist in the document
4. Detect and preserve units (ft, m, psi, bar, bbl, etc.)
O&G DATA CATEGORIES TO LOOK FOR:
- Well Header/Identification (Well Name, Operator, Field, Basin, API#, UWI, Coordinates, Spud/TD Dates)
- Casings & Tubulars (Casing Size, Weight, Grade, Depth Set, Cement Details)
- Deviation/Directional Survey (MD, TVD, Inclination, Azimuth, VS, DLS)
- Formation Tops/Geological Markers (Formation Name, Top Depth, Base Depth, Thickness)
- Mud/Drilling Fluids (Mud Weight, Viscosity, PV, YP, Gel Strength, pH, Cl)
- Production Data (Oil Rate, Gas Rate, Water Rate, GOR, WOR, Choke, Tubing/Casing Pressure)
- Well Tests (DST, LOT, FIT, Flow Rates, Pressures, Permeability, Skin)
- Core/Log Analysis (Porosity, Permeability, Saturation, Net Pay)
- BHA/Equipment (Tool Description, Length, OD, ID, Serial Numbers)
- Daily Operations/Activities (Time, Depth, Activity, Remarks)
"""
        # Task prompt specifying the exact JSON shape the frontend expects.
        prompt = """
Please analyze and extract structured data from this Oil & Gas document.
Structure the response as JSON:
{
"summary": "Brief executive summary describing document type and key findings",
"metadata": {
"fileName": "derived from document or context",
"dateProcessed": "today's date",
"pageCount": 0,
"fileType": "PDF/Image"
},
"tables": [
{
"title": "Category Name",
"rows": [
{ "Column1": "value", "Column2": 123, "__confidence": 0.95 }
],
"page_number": 1,
"visualization_config": null or { "type": "line_chart/bar_chart/scatter_chart", "xAxisKey": "...", "yAxisKeys": [...], "title": "..." }
}
]
}
EXTRACTION RULES:
1. Create a separate table for each data category FOUND in the document
2. Do NOT create empty tables or tables for categories not in the document
3. Use consistent column naming (PascalCase or snake_case with units: Depth_ft, Pressure_psi)
4. Include "__confidence" (0.0-1.0) in EVERY row based on data clarity
5. Include "page_number" for each table (1-indexed)
6. For time-series or correlation data, add "visualization_config" with:
- "type": "line_chart" (time-series), "bar_chart" (comparison), "scatter_chart" (correlation)
- "xAxisKey": column for X axis
- "yAxisKeys": columns for Y axis
- "title": descriptive chart title
7. For non-visualizable data, set "visualization_config" to null
"""
        response_text = None
        # Route to the appropriate provider
        if model_provider == "gemini":
            response_text = analyze_with_gemini(contents, CURRENT_MIME_TYPE, model_name, system_instruction, prompt)
        elif model_provider == "groq":
            # Groq doesn't support PDFs directly, extract text first
            is_scanned = is_scanned_pdf(contents)
            if is_scanned:
                print("Scanned PDF detected. Using EasyOCR...")
                text_content = extract_text_with_ocr(contents)
            else:
                print("Regular PDF detected. Extracting text...")
                text_content = extract_text_from_pdf(contents)
            response_text = analyze_with_groq(text_content, model_name, system_instruction, prompt)
        elif model_provider == "mistral":
            # Mistral also needs text extraction
            is_scanned = is_scanned_pdf(contents)
            if is_scanned:
                print("Scanned PDF detected. Using EasyOCR...")
                text_content = extract_text_with_ocr(contents)
            else:
                print("Regular PDF detected. Extracting text...")
                text_content = extract_text_from_pdf(contents)
            response_text = analyze_with_mistral(text_content, model_name, system_instruction, prompt)
        elif model_provider == "openrouter":
            # OpenRouter needs text extraction
            is_scanned = is_scanned_pdf(contents)
            if is_scanned:
                print("Scanned PDF detected. Using EasyOCR...")
                text_content = extract_text_with_ocr(contents)
            else:
                print("Regular PDF detected. Extracting text...")
                text_content = extract_text_from_pdf(contents)
            response_text = analyze_with_openrouter(text_content, model_name, system_instruction, prompt)
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported model provider: {model_provider}")
        if not response_text:
            raise HTTPException(status_code=500, detail="No response text received from AI model.")
        # Parse the model output, tolerating markdown fences and prose
        cleaned_text = clean_json_string(response_text)
        try:
            parsed = json.loads(cleaned_text)
        except json.JSONDecodeError as e:
            print(f"[Analyze] JSON Parse Error: {e}")
            print(f"[Analyze] Failed Text Check: {cleaned_text[:500]}...")
            # Return a fallback valid response so the UI doesn't crash
            return {
                "summary": f"**Analysis Error**: The AI model generated invalid JSON data.\n\nError details: {str(e)}\n\nThis often happens with smaller or free models. Please try again or switch models.",
                "metadata": {
                    "fileName": file.filename or "Unknown",
                    "dateProcessed": "Error",
                    "pageCount": 0,
                    "fileType": CURRENT_MIME_TYPE
                },
                "tables": []
            }
        # Normalize list-shaped responses into the expected dict-with-tables shape
        if isinstance(parsed, list):
            if len(parsed) == 1 and isinstance(parsed[0], dict) and "tables" in parsed[0]:
                parsed = parsed[0]
            elif parsed and isinstance(parsed[0], dict) and "rows" in parsed[0]:
                parsed = {"tables": parsed}
            else:
                parsed = {"tables": [{"title": "Extracted Data", "rows": parsed}]}
        # Guarantee a "tables" key, salvaging a "data" list if present
        if "tables" not in parsed:
            if "data" in parsed and isinstance(parsed["data"], list):
                parsed["tables"] = [{"title": "Extracted Data", "rows": parsed["data"]}]
            else:
                parsed["tables"] = []
        # Add using_rlm flag for the frontend
        parsed["using_rlm"] = False
        return parsed
    except Exception as e:
        print(f"Error analyzing document: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/chat")
async def chat_document(request: ChatRequest) -> ChatResponse:
    """Answer a question about the document previously loaded via /analyze.

    Reads the module-level CURRENT_* session state; returns 400 when no
    document has been uploaded. Routing:
    - use_rlm=True → RLMEngine deep-research loop over extracted text
    - gemini → reuses the Gemini context cache when one exists, else
      re-attaches the raw document bytes with the message
    - groq / mistral / openrouter → document text (OCR'd if scanned) is
      inlined into the system message, followed by history and the question.

    NOTE(review): the outer `except Exception` also converts HTTPExceptions
    raised inside (e.g. missing API keys) into generic 500s — confirm intended.
    """
    print(f"DEBUG: Chat Request received. use_rlm={request.use_rlm}, model={request.model_name}")
    global CURRENT_FILE_CONTENT, CURRENT_MIME_TYPE, CURRENT_CACHE_NAME
    global CURRENT_MODEL_PROVIDER, CURRENT_MODEL_NAME
    if not CURRENT_FILE_CONTENT:
        raise HTTPException(status_code=400, detail="No document loaded. Please upload a document first.")
    try:
        model_provider = request.model_provider
        model_name = request.model_name
        # Grounding instructions shared by every provider path below.
        system_instruction = """
You are a helpful assistant answering questions about the provided Oil & Gas document.
1. Answer strictly based on the document content. If the answer is not in the document, say so.
2. Where possible, cite the page number where the information is found using the format [Page X].
Example: "The well depth is 1250m [Page 3]."
3. Be concise and accurate.
4. Do not hallucinate.
"""
        # Check for RLM request
        if request.use_rlm:
            print("Using RLM for deep research...")
            # RLM consumes plain text — OCR only when there is no text layer
            is_scanned = is_scanned_pdf(CURRENT_FILE_CONTENT)
            if is_scanned:
                text_content = extract_text_with_ocr(CURRENT_FILE_CONTENT)
            else:
                text_content = extract_text_from_pdf(CURRENT_FILE_CONTENT)
            rlm_engine = RLMEngine(
                doc_content=text_content,
                model_provider=model_provider,
                model_name=model_name
            )
            rlm_result = rlm_engine.run(request.message)
            return ChatResponse(
                answer=rlm_result["answer"],
                citations=[],
                reasoning_trace=rlm_result["reasoning_trace"]
            )
        if model_provider == "gemini":
            # Prefer the context cache created during /analyze: cheaper and
            # avoids re-uploading the document bytes on every turn.
            if CURRENT_CACHE_NAME and CURRENT_MODEL_PROVIDER == "gemini":
                try:
                    print(f"Using existing cache: {CURRENT_CACHE_NAME}")
                    cache = caching.CachedContent.get(CURRENT_CACHE_NAME)
                    model = genai.GenerativeModel.from_cached_content(cached_content=cache)
                    # Convert history to Gemini format ("model" for assistant turns)
                    gemini_history = []
                    for msg in request.history:
                        role = "user" if msg.get("role") == "user" else "model"
                        gemini_history.append({
                            "role": role,
                            "parts": [msg.get("content", "")]
                        })
                    chat_session = model.start_chat(history=gemini_history)
                    # Rate limit before the chat API call
                    RATE_LIMITER.wait_if_needed("gemini")
                    response = chat_session.send_message(request.message, request_options={'timeout': 600})
                    return ChatResponse(
                        answer=response.text,
                        citations=[]
                    )
                except Exception as e:
                    # Cache likely expired or was deleted — clear it and fall
                    # through to the standard flow below.
                    print(f"Error using cache in chat, falling back: {e}")
                    CURRENT_CACHE_NAME = None
            # Standard Gemini flow (no usable cache)
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise HTTPException(status_code=500, detail="GEMINI_API_KEY is missing")
            genai.configure(api_key=api_key)
            model = genai.GenerativeModel(
                model_name=f"models/{model_name}",
                system_instruction=system_instruction
            )
            # Convert history to Gemini format
            gemini_history = []
            for msg in request.history:
                role = "user" if msg.get("role") == "user" else "model"
                gemini_history.append({
                    "role": role,
                    "parts": [msg.get("content", "")]
                })
            chat_session = model.start_chat(history=gemini_history)
            # Rate limit before the chat API call
            RATE_LIMITER.wait_if_needed("gemini")
            # Without a cache, the raw document bytes ride along with the message
            response = chat_session.send_message([
                {'mime_type': CURRENT_MIME_TYPE, 'data': CURRENT_FILE_CONTENT},
                request.message
            ], request_options={'timeout': 600})
            return ChatResponse(
                answer=response.text,
                citations=[]
            )
        elif model_provider == "groq":
            # Extract text for Groq (no native PDF support)
            is_scanned = is_scanned_pdf(CURRENT_FILE_CONTENT)
            if is_scanned:
                text_content = extract_text_with_ocr(CURRENT_FILE_CONTENT)
            else:
                text_content = extract_text_from_pdf(CURRENT_FILE_CONTENT)
            api_key = os.getenv("GROQ_API_KEY")
            if not api_key or api_key == "your_groq_api_key_here":
                raise HTTPException(status_code=500, detail="GROQ_API_KEY is missing or not configured")
            client = Groq(api_key=api_key)
            # Build messages: system prompt carrying the full document text first
            messages = [
                {
                    "role": "system",
                    "content": f"{system_instruction}\n\nDocument Content:\n{text_content}"
                }
            ]
            # Add conversation history
            for msg in request.history:
                role = msg.get("role", "user")
                messages.append({
                    "role": role,
                    "content": msg.get("content", "")
                })
            # Add current user message
            messages.append({
                "role": "user",
                "content": request.message
            })
            completion = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0.7,
                max_tokens=4096
            )
            return ChatResponse(
                answer=completion.choices[0].message.content,
                citations=[]
            )
        elif model_provider == "mistral":
            # Extract text for Mistral
            is_scanned = is_scanned_pdf(CURRENT_FILE_CONTENT)
            if is_scanned:
                text_content = extract_text_with_ocr(CURRENT_FILE_CONTENT)
            else:
                text_content = extract_text_from_pdf(CURRENT_FILE_CONTENT)
            api_key = os.getenv("MISTRAL_API_KEY")
            if not api_key or api_key == "your_mistral_api_key_here":
                raise HTTPException(status_code=500, detail="MISTRAL_API_KEY is missing or not configured")
            client = Mistral(api_key=api_key)
            # Build messages: system prompt carrying the full document text first
            messages = [
                {
                    "role": "system",
                    "content": f"{system_instruction}\n\nDocument Content:\n{text_content}"
                }
            ]
            # Add conversation history
            for msg in request.history:
                role = msg.get("role", "user")
                messages.append({
                    "role": role,
                    "content": msg.get("content", "")
                })
            # Add current user message
            messages.append({
                "role": "user",
                "content": request.message
            })
            chat_response = client.chat.complete(
                model=model_name,
                messages=messages,
                temperature=0.7,
                max_tokens=4096
            )
            return ChatResponse(
                answer=chat_response.choices[0].message.content,
                citations=[]
            )
        elif model_provider == "openrouter":
            # Extract text for OpenRouter
            is_scanned = is_scanned_pdf(CURRENT_FILE_CONTENT)
            if is_scanned:
                text_content = extract_text_with_ocr(CURRENT_FILE_CONTENT)
            else:
                text_content = extract_text_from_pdf(CURRENT_FILE_CONTENT)
            api_key = os.getenv("OPENROUTER_API_KEY")
            if not api_key:
                raise HTTPException(status_code=500, detail="OPENROUTER_API_KEY is missing")
            client = OpenAI(
                base_url="https://openrouter.ai/api/v1",
                api_key=api_key,
            )
            # Build messages: system prompt carrying the full document text first
            messages = [
                {
                    "role": "system",
                    "content": f"{system_instruction}\n\nDocument Content:\n{text_content}"
                }
            ]
            # Add conversation history
            for msg in request.history:
                role = msg.get("role", "user")
                messages.append({
                    "role": role,
                    "content": msg.get("content", "")
                })
            # Add current user message
            messages.append({
                "role": "user",
                "content": request.message
            })
            completion = client.chat.completions.create(
                model=model_name,
                messages=messages,
            )
            return ChatResponse(
                answer=completion.choices[0].message.content,
                citations=[]
            )
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported model provider: {model_provider}")
    except Exception as e:
        print(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Mount frontend static files
if os.path.exists(FRONTEND_DIST_PATH):
    # Mount assets (js, css, etc.)
    app.mount("/assets", StaticFiles(directory=os.path.join(FRONTEND_DIST_PATH, "assets")), name="assets")

    # Catch-all route for SPA (API routes win because they were registered first)
    @app.get("/{full_path:path}")
    async def serve_spa(full_path: str):
        """Serve a real static file if one exists, else fall back to index.html.

        Security fix: the requested path is resolved with os.path.abspath and
        confined to the static root before serving — the previous direct
        os.path.join allowed `..` segments in the URL to escape the static
        directory and serve arbitrary files (path traversal).
        """
        root = os.path.abspath(FRONTEND_DIST_PATH)
        candidate = os.path.abspath(os.path.join(root, full_path))
        # Only serve regular files strictly inside the static root.
        if candidate.startswith(root + os.sep) and os.path.isfile(candidate):
            return FileResponse(candidate)
        # Otherwise serve index.html so client-side routing can take over.
        return FileResponse(os.path.join(root, "index.html"))
if __name__ == "__main__":
    # Local/dev entrypoint. Port 7860 — presumably chosen for Hugging Face
    # Spaces hosting; confirm before changing for other deployments.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)