# NOTE: removed pasted viewer banner (file size, commit hash, line-number gutter)
# api_server.py
import os
import shutil
import tempfile
import uvicorn
import json
import logging
import pandas as pd
from pathlib import Path
from typing import List, Optional
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
# ==============================================
# Logging Configuration
# ==============================================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Backend imports are optional at import time: if llm.py (or config) cannot be
# imported, the server still starts and /api/search reports "unavailable"
# instead of crashing the whole process.
# NOTE(review): if this except path runs, `config` is also left undefined and
# get_llm_instance() would raise NameError — confirm that is acceptable.
try:
    from llm import generate_answer, LLMClient, OpenAILLMClient
    import config
except ImportError as e:
    logger.warning(f"Could not import from llm.py: {e}")
    generate_answer = None  # sentinel checked by the /api/search handler
    LLMClient = None
    OpenAILLMClient = None

app = FastAPI(title="Amazon Multimodal API")

# ==============================
# Global LLM Instance (Singleton)
# ==============================
# Lazily populated by get_llm_instance(); module-global so the expensive
# model initialization happens at most once per process.
LLM_INSTANCE = None
def get_llm_instance():
    """Return the process-wide LLM client, creating it on first use.

    Prefers the OpenAI client when ``config.USE_OPENAI`` is set and the
    class imported successfully; otherwise falls back to the local
    HuggingFace client. Raises if neither backend is available or if
    initialization fails (the error is logged before re-raising).
    """
    global LLM_INSTANCE
    # Fast path: the singleton already exists.
    if LLM_INSTANCE is not None:
        return LLM_INSTANCE
    try:
        prefer_openai = config.USE_OPENAI and OpenAILLMClient is not None
        if prefer_openai:
            logger.info(f"Initializing OpenAI {config.OPENAI_MODEL}...")
            LLM_INSTANCE = OpenAILLMClient(
                api_key=config.OPENAI_API_KEY,
                model=config.OPENAI_MODEL,
                max_tokens=config.OPENAI_MAX_TOKENS,
                temperature=config.OPENAI_TEMPERATURE
            )
            logger.info(f"OpenAI {config.OPENAI_MODEL} loaded successfully!")
        elif LLMClient is not None:
            logger.info(f"Initializing local model {config.LLM_MODEL} (this may take a few minutes)...")
            LLM_INSTANCE = LLMClient(model_name=config.LLM_MODEL)
            logger.info("Local LLM model loaded successfully!")
        else:
            # Neither backend imported at module load time.
            raise ImportError("No LLM client available")
    except Exception as e:
        logger.error(f"Failed to load LLM model: {e}")
        raise
    return LLM_INSTANCE
# ==============================
# 0. Preload data (for Header statistics)
# ==============================
# Product catalog used to compute header statistics.
CSV_PATH = "amazon_multimodal_clean.csv"
# Served verbatim by GET /api/info; mutated in place by load_stats().
STATS = {
    "product_count": 0,
    "category_count": 0,
    "index_ready": False
}
def load_stats():
    """Refresh the global STATS dict from the product CSV and index directory."""
    global STATS
    # Presence of the ChromaDB directory is taken as "index is ready".
    STATS["index_ready"] = os.path.isdir("chromadb_store")
    if not os.path.exists(CSV_PATH):
        logger.warning(f"CSV file not found at: {CSV_PATH}")
        return
    try:
        frame = pd.read_csv(CSV_PATH)
        STATS["product_count"] = len(frame)
        if "main_category" in frame.columns:
            STATS["category_count"] = frame["main_category"].nunique()
        else:
            STATS["category_count"] = 0
        logger.info(f"Loaded Stats: {STATS}")
    except Exception as e:
        # Best-effort: a malformed CSV must not prevent the server from starting.
        logger.error(f"Error loading CSV: {e}")


# Populate statistics once at import time so /api/info is correct immediately.
load_stats()
# ==============================
# 4. Startup Event: Build Index if Missing
# ==============================
@app.on_event("startup")
async def startup_event():
    """Build the vector index (if absent) and warm up the LLM at startup.

    Runs once when the ASGI server starts. Failures are logged but not
    re-raised so the HTTP server still comes up and can report its state
    via /health.
    """
    # Local import: rag is heavy and only needed for the one-time index build.
    from rag import build_index

    # Check for the ChromaDB database file itself, not just the directory:
    # an empty "chromadb_store" dir can exist without a usable index.
    db_file = os.path.join("chromadb_store", "chroma.sqlite3")
    if not os.path.exists(db_file):
        logger.info("=" * 60)
        logger.info("ChromaDB index not found. Building index...")
        logger.info("This may take 2-5 minutes on first startup.")
        logger.info("=" * 60)
        try:
            build_index(
                csv_path="amazon_multimodal_clean.csv",
                persist_dir="chromadb_store",
                max_items=None  # Use full dataset
            )
            logger.info("✅ Index built successfully!")
        except Exception as e:
            logger.error(f"❌ Failed to build index: {e}")
    else:
        logger.info("✅ ChromaDB index found. Ready to serve requests.")

    # Pre-initialize the LLM so the first /api/search request avoids a cold start.
    try:
        logger.info("Pre-initializing LLM instance...")
        get_llm_instance()
        logger.info("✅ LLM instance ready!")
    except Exception as e:
        logger.warning(f"⚠️ Failed to pre-initialize LLM: {e}")
# ==============================
# 1. CORS Configuration
# ==============================
# Wide-open CORS is convenient for development; tighten allow_origins before
# exposing this server publicly ("*" combined with credentials is permissive).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins in development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ==============================
# 2. API Endpoints (must be defined BEFORE mounting static files!)
# ==============================
@app.get("/api/info")
async def get_system_info():
    """Return catalog statistics for the frontend header display."""
    # The index may be built after startup, so re-probe the directory on
    # every call before returning the shared STATS dict.
    index_present = os.path.isdir("chromadb_store")
    STATS["index_ready"] = index_present
    return STATS
@app.get("/health")
@app.head("/health")
async def health_check():
    """Liveness/readiness probe for Docker and HF Spaces monitoring."""
    report = {
        "status": "healthy",
        "index_ready": os.path.isdir("chromadb_store"),
        "llm_initialized": LLM_INSTANCE is not None,
    }
    return report
@app.post("/api/search")
async def search(
    query: str = Form(""),
    mode: str = Form("multimodal"),
    history: str = Form("[]"),
    image: Optional[UploadFile] = File(None)
):
    """Main search endpoint supporting text, image, and multimodal queries.

    Form fields:
        query:   free-text question (may be empty for image-only search)
        mode:    retrieval mode, passed through to generate_answer()
        history: JSON-encoded list of prior chat turns; malformed input
                 degrades to an empty history rather than failing the request
        image:   optional uploaded image for visual search

    Returns a JSON object with the generated answer, the matched products
    (name/category/similarity/image URL), and the retrieval method used.

    Raises:
        HTTPException 400 if the uploaded image cannot be saved.
        HTTPException 500 if the backend is unavailable or generation fails.
    """
    logger.info(f"Search request: mode={mode}, query_length={len(query)}, has_image={image is not None}")

    if not generate_answer:
        # llm.py failed to import at module load time (see the guarded import).
        logger.error("Backend logic (llm.py) not loaded")
        raise HTTPException(status_code=500, detail="Service temporarily unavailable")

    # Persist the upload to a temp file because downstream code expects a path.
    temp_image_path = None
    if image:
        try:
            suffix = Path(image.filename).suffix or ".jpg"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                shutil.copyfileobj(image.file, tmp)
                temp_image_path = tmp.name
            logger.info(f"Saved uploaded image to: {temp_image_path}")
        except Exception as e:
            logger.error(f"Failed to save uploaded image: {e}")
            raise HTTPException(status_code=400, detail="Failed to process image upload")

    # Parse chat history from its JSON string form.
    try:
        chat_history = json.loads(history)
    except Exception as e:
        logger.warning(f"Failed to parse chat history: {e}")
        chat_history = []

    try:
        # Use the global singleton so each request skips model initialization.
        llm_instance = get_llm_instance()
        result = generate_answer(
            user_question=query,
            image_path=temp_image_path,
            mode=mode,
            chat_history=chat_history,
            llm_client=llm_instance
        )
        logger.info(f"Search successful: returned {len(result.get('products', []))} products")

        processed_products = []
        for p in result.get("products", []):
            raw_path = p.get("image_path", "")
            filename = os.path.basename(raw_path)
            # BUG FIX: the URL previously hard-coded the literal "(unknown)"
            # instead of interpolating `filename`, so every product image
            # pointed at /product_images/(unknown) and 404'd. The URL must
            # match the /product_images static mount below.
            web_url = f"/product_images/{filename}" if filename else ""
            processed_products.append({
                "name": p.get("name", "Unknown Product"),
                "category": p.get("category", "General"),
                # Retrieval returns a distance; convert to a similarity score.
                "similarity": 1 - p.get("distance", 0.0),
                "image": web_url,
            })

        return {
            "answer": result.get("answer", "No answer generated."),
            "products": processed_products,
            "retrieval_method": result.get("retrieval_method", mode),
            "status": "success"
        }
    except Exception as e:
        logger.error(f"Search API error: {str(e)}", exc_info=True)
        # Don't expose internal error details to the client.
        raise HTTPException(status_code=500, detail="An error occurred processing your search")
    finally:
        # Clean up the temporary uploaded image regardless of outcome.
        if temp_image_path and os.path.exists(temp_image_path):
            try:
                os.unlink(temp_image_path)
                logger.debug(f"Cleaned up temporary file: {temp_image_path}")
            except Exception as e:
                logger.warning(f"Failed to clean up temporary file {temp_image_path}: {e}")
# ==============================
# 3. Static File Mounting
# ==============================
# A. Product images directory. Mounted only if it exists so the server still
#    starts on a fresh checkout. The /product_images URL prefix must match
#    the image URLs constructed in the /api/search handler above.
if os.path.exists("images"):
    app.mount("/product_images", StaticFiles(directory="images"), name="images")
# B. Frontend static files - serve individual files to avoid blocking API routes
from fastapi.responses import FileResponse
@app.get("/")
async def serve_index():
    """Serve the frontend entry page (index.html)."""
    index_page = "frontend/index.html"
    return FileResponse(index_page)
@app.get("/main.js")
async def serve_main_js():
    """Serve the frontend JavaScript bundle."""
    bundle = "frontend/main.js"
    return FileResponse(bundle)
@app.get("/amazon-logo.png")
async def serve_logo():
    """Serve the site logo image."""
    logo = "frontend/amazon-logo.png"
    return FileResponse(logo)
if __name__ == "__main__":
    # Development entry point (production deployments run uvicorn externally).
    # NOTE(review): config is also imported in the guarded try at the top of
    # the file; this local import re-binds it for the script path — confirm
    # behavior when the top-level import failed.
    import config
    uvicorn.run(app, host=config.API_HOST, port=config.API_PORT)