# Easonwangzk's picture
# Initial commit with Git LFS
# ab26b91
# raw
# history blame
# 9.56 kB
# api_server.py
import os
import shutil
import tempfile
import uvicorn
import json
import logging
import pandas as pd
from pathlib import Path
from typing import List, Optional
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
# ==============================================
# Logging Configuration
# ==============================================
# Root logging setup: INFO and above, each record stamped with the time,
# the emitting logger's name and the level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Logger used by everything defined in this module.
logger = logging.getLogger(__name__)
# The LLM backend is optional at import time: when it cannot be imported the
# three names below are bound to None so callers can detect the absence.
# Note that ``config`` is left unbound in that case.
try:
    from llm import generate_answer, LLMClient, OpenAILLMClient
    import config
except ImportError as import_error:
    logger.warning(f"Could not import from llm.py: {import_error}")
    generate_answer = LLMClient = OpenAILLMClient = None
# FastAPI application object; the middleware, routes and static mounts defined
# further down in this file all attach to it.
app = FastAPI(title="Amazon Multimodal API")
# ==============================
# Global LLM Instance (Singleton)
# ==============================
# Cached client; stays None until get_llm_instance() succeeds once.
LLM_INSTANCE = None


def get_llm_instance():
    """Return the shared LLM client, constructing it on the first call.

    Backend selection:
      * ``OpenAILLMClient`` when ``config.USE_OPENAI`` is truthy and the class
        was importable;
      * otherwise ``LLMClient`` (built from ``config.LLM_MODEL``) when that
        class was importable.

    Raises:
        ImportError: neither client class is available.
        Exception: anything raised while building the client is logged and
            re-raised unchanged; the cache is left empty in that case.
    """
    global LLM_INSTANCE

    # Fast path: a client was already built by an earlier call.
    if LLM_INSTANCE is not None:
        return LLM_INSTANCE

    try:
        if config.USE_OPENAI and OpenAILLMClient is not None:
            # OpenAI-hosted model named by config.OPENAI_MODEL
            logger.info(f"Initializing OpenAI {config.OPENAI_MODEL}...")
            openai_settings = dict(
                api_key=config.OPENAI_API_KEY,
                model=config.OPENAI_MODEL,
                max_tokens=config.OPENAI_MAX_TOKENS,
                temperature=config.OPENAI_TEMPERATURE,
            )
            LLM_INSTANCE = OpenAILLMClient(**openai_settings)
            logger.info(f"OpenAI {config.OPENAI_MODEL} loaded successfully!")
        elif LLMClient is not None:
            # Locally loaded model named by config.LLM_MODEL
            logger.info(f"Initializing local model {config.LLM_MODEL} (this may take a few minutes)...")
            LLM_INSTANCE = LLMClient(model_name=config.LLM_MODEL)
            logger.info("Local LLM model loaded successfully!")
        else:
            raise ImportError("No LLM client available")
    except Exception as load_error:
        logger.error(f"Failed to load LLM model: {load_error}")
        raise

    return LLM_INSTANCE
# ==============================
# 0. Preload data (for Header statistics)
# ==============================
# Product catalogue read for the header statistics (relative to the working directory).
CSV_PATH = "amazon_multimodal_clean.csv"
# Figures returned by /api/info; filled in by load_stats().
STATS = {
    "product_count": 0,
    "category_count": 0,
    "index_ready": False,
}


def load_stats():
    """Fill ``STATS`` from the product CSV and the on-disk vector index.

    * ``index_ready`` is True when ``chromadb_store/chroma.sqlite3`` exists.
      The database file is checked rather than just the directory, which is
      the same criterion the startup hook uses to decide whether the index
      must be built.
    * ``product_count`` is the number of rows in ``CSV_PATH``.
    * ``category_count`` is the number of distinct ``main_category`` values,
      or 0 when the CSV has no such column.

    This is best-effort: a missing CSV is logged as a warning and a CSV that
    cannot be read is logged as an error (with traceback); neither raises, and
    the counters keep their previous values.
    """
    # STATS is only mutated in place, so no ``global`` declaration is needed.
    STATS["index_ready"] = os.path.exists(
        os.path.join("chromadb_store", "chroma.sqlite3")
    )

    if not os.path.exists(CSV_PATH):
        logger.warning(f"CSV file not found at: {CSV_PATH}")
        return

    try:
        df = pd.read_csv(CSV_PATH)
        STATS["product_count"] = len(df)
        STATS["category_count"] = (
            df["main_category"].nunique() if "main_category" in df.columns else 0
        )
        logger.info(f"Loaded Stats: {STATS}")
    except Exception as e:
        # Statistics are cosmetic; never let a bad CSV stop the server from importing.
        logger.error(f"Error loading CSV: {e}", exc_info=True)
# Populate STATS once, at import time, so /api/info already has figures to
# return when the first request arrives.
load_stats()
# ==============================
# 4. Startup Event: Build Index if Missing
# ==============================
@app.on_event("startup")
async def startup_event():
    """Build the vector index if it is missing, then warm up the LLM client.

    The index is considered present when ``chromadb_store/chroma.sqlite3``
    exists (the database file, not merely the directory). When it is absent,
    ``rag.build_index`` is run over the full CSV at ``CSV_PATH``.

    Neither step is fatal: a failed index build is logged as an error and a
    failed LLM warm-up is logged as a warning, and the application still
    starts. An ImportError from ``rag`` is not caught and aborts startup.
    """
    from rag import build_index

    persist_dir = "chromadb_store"
    # Check if ChromaDB database file exists (not just the directory)
    db_file = os.path.join(persist_dir, "chroma.sqlite3")

    if os.path.exists(db_file):
        logger.info("✅ ChromaDB index found. Ready to serve requests.")
    else:
        banner = "=" * 60
        logger.info(banner)
        logger.info("ChromaDB index not found. Building index...")
        logger.info("This may take 2-5 minutes on first startup.")
        logger.info(banner)
        try:
            build_index(
                csv_path=CSV_PATH,  # same file load_stats() reads
                persist_dir=persist_dir,
                max_items=None,  # Use full dataset
            )
            logger.info("✅ Index built successfully!")
        except Exception as e:
            # Keep the traceback: this is the only record of why the build failed.
            logger.error(f"❌ Failed to build index: {e}", exc_info=True)

    # Pre-initialize LLM to avoid cold start on the first search request
    try:
        logger.info("Pre-initializing LLM instance...")
        get_llm_instance()
        logger.info("✅ LLM instance ready!")
    except Exception as e:
        logger.warning(f"⚠️ Failed to pre-initialize LLM: {e}")
# ==============================
# 1. CORS Configuration
# ==============================
# Wildcard origins are kept for development, but credentials are switched off:
# the FastAPI CORS documentation states that allow_origins, allow_methods and
# allow_headers must not be ["*"] when allow_credentials is True. No endpoint
# in this file reads cookies or auth headers. If credentialed cross-origin
# requests are ever needed, list the allowed origins explicitly instead of
# re-enabling this flag.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins in development
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ==============================
# 2. API Endpoints (must be defined BEFORE mounting static files!)
# ==============================
@app.get("/api/info")
async def get_system_info():
    """Return system statistics for frontend Header display.

    ``index_ready`` is re-evaluated on every call because the index may be
    created while the server is running. It is True when
    ``chromadb_store/chroma.sqlite3`` exists -- the database file, which is
    the same criterion the startup hook uses, rather than just the directory.

    Returns:
        A snapshot (shallow copy) of ``STATS`` with keys ``product_count``,
        ``category_count`` and ``index_ready``.
    """
    STATS["index_ready"] = os.path.exists(
        os.path.join("chromadb_store", "chroma.sqlite3")
    )
    # Hand back a copy so the response cannot alias the module-level dict.
    return dict(STATS)
@app.get("/health")
@app.head("/health")
async def health_check():
    """Health check endpoint for Docker and HF Spaces monitoring.

    Returns:
        A dict with
          * ``status``: always ``"healthy"`` (the handler ran);
          * ``index_ready``: whether ``chromadb_store/chroma.sqlite3`` exists,
            i.e. the same database-file check the startup hook performs;
          * ``llm_initialized``: whether the shared LLM client has been built.
    """
    # The module-level ``os`` import is used; no local re-import is needed.
    index_db = os.path.join("chromadb_store", "chroma.sqlite3")
    return {
        "status": "healthy",
        "index_ready": os.path.exists(index_db),
        "llm_initialized": LLM_INSTANCE is not None,
    }
def _save_upload_to_temp(upload):
    """Copy an uploaded file into a named temporary file and return its path.

    The file is created with ``delete=False`` so it outlives this function;
    the caller owns it and must remove it. If the copy fails, the partially
    written file is removed here before the exception propagates.
    """
    # ``filename`` may be missing on the upload; fall back to ".jpg" as before.
    suffix = Path(upload.filename or "").suffix or ".jpg"
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp:
            shutil.copyfileobj(upload.file, tmp)
    except Exception:
        try:
            os.unlink(tmp.name)
        except OSError as cleanup_error:
            logger.warning(f"Failed to clean up temporary file {tmp.name}: {cleanup_error}")
        raise
    return tmp.name


def _parse_chat_history(raw_history):
    """Decode the JSON-encoded chat history, falling back to an empty list."""
    try:
        parsed = json.loads(raw_history)
    except (ValueError, TypeError, RecursionError) as e:
        logger.warning(f"Failed to parse chat history: {e}")
        return []
    # The form default is "[]", so a JSON array is assumed to be the expected
    # shape -- TODO confirm against generate_answer. Anything else is dropped.
    if not isinstance(parsed, list):
        logger.warning(f"Failed to parse chat history: expected a JSON array, got {type(parsed).__name__}")
        return []
    return parsed


def _to_frontend_product(product):
    """Map one retrieved product record to the dict shape sent to the frontend."""
    filename = os.path.basename(product.get("image_path") or "")
    return {
        "name": product.get("name", "Unknown Product"),
        "category": product.get("category", "General"),
        "similarity": 1 - product.get("distance", 0.0),
        # Construct accessible URL for frontend (served by the /product_images mount)
        "image": f"/product_images/{filename}" if filename else "",
    }


@app.post("/api/search")
async def search(
    query: str = Form(""),
    mode: str = Form("multimodal"),
    history: str = Form("[]"),
    image: Optional[UploadFile] = File(None)
):
    """
    Main search endpoint supporting text, image, and multimodal queries.

    Form fields:
        query: the user's question (may be empty).
        mode: retrieval mode passed through to ``generate_answer``.
        history: JSON-encoded chat history; invalid or non-array input is
            treated as an empty history.
        image: optional uploaded image, saved to a temporary file for the
            duration of the request.

    Returns:
        A dict with ``answer``, ``products`` (name, category, similarity,
        image URL), ``retrieval_method`` and ``status``.

    Raises:
        HTTPException(500): the backend module is not loaded, or answering
            failed (details are logged, not returned to the client).
        HTTPException(400): the uploaded image could not be saved.
    """
    logger.info(f"Search request: mode={mode}, query_length={len(query)}, has_image={image is not None}")
    if not generate_answer:
        logger.error("Backend logic (llm.py) not loaded")
        raise HTTPException(status_code=500, detail="Service temporarily unavailable")

    # Parse chat history first so nothing runs between creating the temp file
    # and entering the try/finally that removes it.
    chat_history = _parse_chat_history(history)

    temp_image_path = None
    if image:
        try:
            temp_image_path = _save_upload_to_temp(image)
            logger.info(f"Saved uploaded image to: {temp_image_path}")
        except Exception as e:
            logger.error(f"Failed to save uploaded image: {e}")
            raise HTTPException(status_code=400, detail="Failed to process image upload") from e

    try:
        # Use the global LLM instance for better performance
        llm_instance = get_llm_instance()
        result = generate_answer(
            user_question=query,
            image_path=temp_image_path,
            mode=mode,
            chat_history=chat_history,
            llm_client=llm_instance
        )
        products = result.get("products", [])
        logger.info(f"Search successful: returned {len(products)} products")
        return {
            "answer": result.get("answer", "No answer generated."),
            "products": [_to_frontend_product(p) for p in products],
            "retrieval_method": result.get("retrieval_method", mode),
            "status": "success"
        }
    except Exception as e:
        logger.error(f"Search API error: {str(e)}", exc_info=True)
        # Don't expose internal error details to client
        raise HTTPException(status_code=500, detail="An error occurred processing your search") from e
    finally:
        # Clean up temporary uploaded image
        if temp_image_path:
            try:
                os.unlink(temp_image_path)
                logger.debug(f"Cleaned up temporary file: {temp_image_path}")
            except FileNotFoundError:
                pass  # already gone; nothing to clean up
            except Exception as e:
                logger.warning(f"Failed to clean up temporary file {temp_image_path}: {e}")
# ==============================
# 3. Static File Mounting
# ==============================
# A. Product images directory
# Mount only when "images" is an actual directory: StaticFiles verifies the
# directory when it is constructed and raises if the path is missing or is a
# regular file, which os.path.exists() alone would let through.
if os.path.isdir("images"):
    app.mount("/product_images", StaticFiles(directory="images"), name="images")
# B. Frontend static files - serve individual files to avoid blocking API routes
from fastapi.responses import FileResponse
@app.get("/")
async def serve_index():
    """Return the frontend entry page, ``frontend/index.html``."""
    index_page = "frontend/index.html"
    return FileResponse(index_page)
@app.get("/main.js")
async def serve_main_js():
    """Return the frontend script, ``frontend/main.js``."""
    script_file = "frontend/main.js"
    return FileResponse(script_file)
@app.get("/amazon-logo.png")
async def serve_logo():
    """Return the logo image, ``frontend/amazon-logo.png``."""
    logo_file = "frontend/amazon-logo.png"
    return FileResponse(logo_file)
if __name__ == "__main__":
    # Unlike the guarded import near the top of the file, this one is not
    # wrapped in try/except, so a missing config module stops a direct run.
    import config

    bind_host = config.API_HOST
    bind_port = config.API_PORT
    uvicorn.run(app, host=bind_host, port=bind_port)