# api_server.py
import os
import shutil
import tempfile
import uvicorn
import json
import logging
import pandas as pd
from pathlib import Path
from typing import List, Optional
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware

# ==============================================
# Logging Configuration
# ==============================================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Project-local modules are imported defensively so the API process can still
# start (and report a clear error per request) when the LLM backend is missing.
# `config` is imported on its own: previously a failing `llm` import also left
# `config` unbound, turning later errors into a confusing NameError.
try:
    import config
except ImportError as e:
    logger.warning(f"Could not import config.py: {e}")
    config = None

try:
    from llm import generate_answer, LLMClient, OpenAILLMClient
except ImportError as e:
    logger.warning(f"Could not import from llm.py: {e}")
    generate_answer = None
    LLMClient = None
    OpenAILLMClient = None

app = FastAPI(title="Amazon Multimodal API")

# ==============================
# Global LLM Instance (Singleton)
# ==============================
# Lazily created by get_llm_instance(); None until the first successful load.
LLM_INSTANCE = None


def get_llm_instance():
    """Return the process-wide LLM client, creating it on first use.

    Uses ``OpenAILLMClient`` when ``config.USE_OPENAI`` is truthy and that class
    was importable; otherwise falls back to the local ``LLMClient`` built from
    ``config.LLM_MODEL``. The created client is cached in ``LLM_INSTANCE``.

    Initialization is not synchronized here; callers must not invoke this
    concurrently from several threads before the instance exists.

    Raises:
        ImportError: if ``config`` or every LLM client class failed to import.
        Exception: any error from the client constructor is logged and
            re-raised unchanged.
    """
    global LLM_INSTANCE
    if LLM_INSTANCE is not None:
        return LLM_INSTANCE

    try:
        if config is None:
            raise ImportError("config module not available")
        if config.USE_OPENAI and OpenAILLMClient is not None:
            # Use the OpenAI API client
            logger.info(f"Initializing OpenAI {config.OPENAI_MODEL}...")
            LLM_INSTANCE = OpenAILLMClient(
                api_key=config.OPENAI_API_KEY,
                model=config.OPENAI_MODEL,
                max_tokens=config.OPENAI_MAX_TOKENS,
                temperature=config.OPENAI_TEMPERATURE
            )
            logger.info(f"OpenAI {config.OPENAI_MODEL} loaded successfully!")
        elif LLMClient is not None:
            # Use local HuggingFace model
            logger.info(f"Initializing local model {config.LLM_MODEL} (this may take a few minutes)...")
            LLM_INSTANCE = LLMClient(model_name=config.LLM_MODEL)
            logger.info("Local LLM model loaded successfully!")
        else:
            raise ImportError("No LLM client available")
    except Exception as e:
        logger.error(f"Failed to load LLM model: {e}")
        raise

    return LLM_INSTANCE
# ==============================
# 0. Preload data (for Header statistics)
# ==============================
CSV_PATH = "amazon_multimodal_clean.csv"

# Directory the ChromaDB index is persisted to, and the SQLite file whose
# presence means an index has actually been built (an empty or half-created
# directory does not count as ready).
CHROMA_DIR = "chromadb_store"
CHROMA_DB_FILE = os.path.join(CHROMA_DIR, "chroma.sqlite3")

STATS = {
    "product_count": 0,
    "category_count": 0,
    "index_ready": False
}


def load_stats():
    """Populate ``STATS`` from the product CSV and the index on disk.

    Sets ``index_ready`` from the presence of the ChromaDB SQLite file,
    ``product_count`` to the CSV row count, and ``category_count`` to the number
    of distinct ``main_category`` values (0 when that column is absent).
    A missing or unreadable CSV is logged and leaves the counts untouched.
    """
    STATS["index_ready"] = os.path.isfile(CHROMA_DB_FILE)

    if not os.path.exists(CSV_PATH):
        logger.warning(f"CSV file not found at: {CSV_PATH}")
        return

    try:
        df = pd.read_csv(CSV_PATH)
        STATS["product_count"] = len(df)
        STATS["category_count"] = df["main_category"].nunique() if "main_category" in df.columns else 0
        logger.info(f"Loaded Stats: {STATS}")
    except Exception as e:
        logger.error(f"Error loading CSV: {e}")


# Execute loading on startup (module import time)
load_stats()


# ==============================
# 4. Startup Event: Build Index if Missing
# ==============================
@app.on_event("startup")
async def startup_event():
    """Build the vector index if it is missing, then pre-load the LLM.

    Both steps are best-effort: failures are logged and the server still
    starts, so /health and /api/info remain reachable.
    """
    # Check the ChromaDB database file itself, not just the directory.
    if not os.path.isfile(CHROMA_DB_FILE):
        logger.info("=" * 60)
        logger.info("ChromaDB index not found. Building index...")
        logger.info("This may take 2-5 minutes on first startup.")
        logger.info("=" * 60)
        try:
            # Imported here (inside the try) so a missing/broken `rag` module
            # is logged instead of aborting application startup, and is not
            # required at all when the index already exists.
            from rag import build_index

            build_index(
                csv_path=CSV_PATH,
                persist_dir=CHROMA_DIR,
                max_items=None  # Use full dataset
            )
            logger.info("✅ Index built successfully!")
        except Exception as e:
            logger.error(f"❌ Failed to build index: {e}")
        STATS["index_ready"] = os.path.isfile(CHROMA_DB_FILE)
    else:
        logger.info("✅ ChromaDB index found. Ready to serve requests.")

    # Pre-initialize LLM to avoid cold start on the first search request
    try:
        logger.info("Pre-initializing LLM instance...")
        get_llm_instance()
        logger.info("✅ LLM instance ready!")
    except Exception as e:
        logger.warning(f"⚠️ Failed to pre-initialize LLM: {e}")
# ==============================
# 1. CORS Configuration
# ==============================
# NOTE(review): with allow_origins=["*"] plus allow_credentials=True, Starlette
# echoes the caller's Origin back, so any site can make credentialed requests.
# Restrict allow_origins before deploying outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins in development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ==============================
# 2. API Endpoints (must be defined BEFORE mounting static files!)
# ==============================
import threading

from fastapi.concurrency import run_in_threadpool

# ChromaDB's SQLite file. The startup hook treats its presence (not merely the
# directory's) as "index built", so the readiness checks below use the same test.
_INDEX_DB_FILE = os.path.join("chromadb_store", "chroma.sqlite3")

# Serializes LLM initialization and generation. generate_answer() is blocking,
# so it runs in a worker thread to keep the event loop free for /health and
# static files; the lock keeps requests using the shared LLM one at a time,
# as they were when the call ran directly on the event loop.
_LLM_LOCK = threading.Lock()


@app.get("/api/info")
async def get_system_info():
    """Return system statistics for the frontend Header display.

    ``index_ready`` is re-checked on every call because the index may be
    created while the server is running.
    """
    STATS["index_ready"] = os.path.isfile(_INDEX_DB_FILE)
    return STATS


@app.get("/health")
@app.head("/health")
async def health_check():
    """Health check endpoint for Docker and HF Spaces monitoring."""
    return {
        "status": "healthy",
        "index_ready": os.path.isfile(_INDEX_DB_FILE),
        "llm_initialized": LLM_INSTANCE is not None
    }


def _save_upload_to_temp(upload):
    """Copy an uploaded file into a new named temp file and return its path.

    Keeps the suffix of the client-supplied filename, using ".jpg" when there
    is none (including when the client sent no filename at all). If copying
    fails, the partially written temp file is removed before re-raising.
    """
    suffix = Path(upload.filename or "").suffix or ".jpg"
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp:
            shutil.copyfileobj(upload.file, tmp)
    except Exception:
        try:
            os.unlink(tmp.name)
        except OSError:
            # Best-effort cleanup; the original copy error is what matters.
            pass
        raise
    return tmp.name


def _parse_chat_history(raw):
    """Decode the JSON ``history`` form field into a list.

    Malformed JSON or a non-list JSON value is logged and replaced by ``[]`` so
    a bad history never fails the whole search.
    """
    try:
        parsed = json.loads(raw)
    except Exception as e:
        logger.warning(f"Failed to parse chat history: {e}")
        return []
    if not isinstance(parsed, list):
        logger.warning(f"Ignoring chat history of type {type(parsed).__name__}; expected a list")
        return []
    return parsed


def _generate_blocking(**kwargs):
    """Get the shared LLM and call generate_answer() while holding _LLM_LOCK.

    Blocking; call it through run_in_threadpool from async code.
    """
    with _LLM_LOCK:
        return generate_answer(llm_client=get_llm_instance(), **kwargs)


def _format_product(p):
    """Convert one retrieved product dict into the shape the frontend expects."""
    filename = os.path.basename(p.get("image_path") or "")
    distance = p.get("distance")
    return {
        "name": p.get("name", "Unknown Product"),
        "category": p.get("category", "General"),
        # A missing/None distance is treated as 0.0, i.e. similarity 1.0.
        "similarity": 1 - (distance if distance is not None else 0.0),
        # Accessible URL under the /product_images static mount.
        "image": f"/product_images/{filename}" if filename else "",
    }


@app.post("/api/search")
async def search(
    query: str = Form(""),
    mode: str = Form("multimodal"),
    history: str = Form("[]"),
    image: Optional[UploadFile] = File(None)
):
    """
    Main search endpoint supporting text, image, and multimodal queries.

    Form fields: ``query`` text, retrieval ``mode``, ``history`` as a JSON list,
    and an optional ``image`` upload (saved to a temp file for the duration of
    the request and deleted afterwards).

    Returns a dict with ``answer``, ``products`` (name, category, similarity,
    image URL), ``retrieval_method`` and ``status``.

    Raises:
        HTTPException(500): llm.py is not loaded, or answer generation failed.
        HTTPException(400): the uploaded image could not be saved.
    """
    logger.info(f"Search request: mode={mode}, query_length={len(query)}, has_image={image is not None}")

    if not generate_answer:
        logger.error("Backend logic (llm.py) not loaded")
        raise HTTPException(status_code=500, detail="Service temporarily unavailable")

    temp_image_path = None
    if image:
        try:
            temp_image_path = _save_upload_to_temp(image)
            logger.info(f"Saved uploaded image to: {temp_image_path}")
        except Exception as e:
            logger.error(f"Failed to save uploaded image: {e}")
            raise HTTPException(status_code=400, detail="Failed to process image upload")

    try:
        chat_history = _parse_chat_history(history)

        # Run the blocking LLM pipeline off the event loop.
        result = await run_in_threadpool(
            _generate_blocking,
            user_question=query,
            image_path=temp_image_path,
            mode=mode,
            chat_history=chat_history,
        )

        products = result.get("products") or []
        logger.info(f"Search successful: returned {len(products)} products")

        return {
            "answer": result.get("answer", "No answer generated."),
            "products": [_format_product(p) for p in products],
            "retrieval_method": result.get("retrieval_method", mode),
            "status": "success"
        }
    except Exception as e:
        logger.error(f"Search API error: {str(e)}", exc_info=True)
        # Don't expose internal error details to client
        raise HTTPException(status_code=500, detail="An error occurred processing your search")
    finally:
        # Clean up temporary uploaded image
        if temp_image_path and os.path.exists(temp_image_path):
            try:
                os.unlink(temp_image_path)
                logger.debug(f"Cleaned up temporary file: {temp_image_path}")
            except Exception as e:
                logger.warning(f"Failed to clean up temporary file {temp_image_path}: {e}")


# ==============================
# 3. Static File Mounting
# ==============================
# A. Product images directory
if os.path.exists("images"):
    app.mount("/product_images", StaticFiles(directory="images"), name="images")
# B. Frontend static files - serve individual files to avoid blocking API routes
from fastapi.responses import FileResponse

# Directory holding the single-page frontend assets.
FRONTEND_DIR = "frontend"


def _frontend_file(filename):
    """Return a FileResponse for ``frontend/<filename>``, or a 404 if missing.

    Starlette's FileResponse only discovers a missing file while sending the
    response and then fails with a server error; checking up front turns a
    missing asset into a proper 404.
    """
    path = os.path.join(FRONTEND_DIR, filename)
    if not os.path.isfile(path):
        logger.warning(f"Frontend file not found: {path}")
        raise HTTPException(status_code=404, detail="Not Found")
    return FileResponse(path)


@app.get("/")
async def serve_index():
    """Serve the main index.html"""
    return _frontend_file("index.html")


@app.get("/main.js")
async def serve_main_js():
    """Serve main.js"""
    return _frontend_file("main.js")


@app.get("/amazon-logo.png")
async def serve_logo():
    """Serve logo"""
    return _frontend_file("amazon-logo.png")


if __name__ == "__main__":
    import config
    uvicorn.run(app, host=config.API_HOST, port=config.API_PORT)