|
|
|
|
|
import os |
|
|
import shutil |
|
|
import tempfile |
|
|
import uvicorn |
|
|
import json |
|
|
import logging |
|
|
import pandas as pd |
|
|
from pathlib import Path |
|
|
from typing import List, Optional |
|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Configure root logging once at import time so every module logger below
# inherits the same level and message format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Import the LLM helpers and project config with a graceful fallback: if
# llm.py (or config.py) is missing or broken, the server still starts and
# endpoints report "Service temporarily unavailable" instead of the whole
# module failing to import.
try:
    from llm import generate_answer, LLMClient, OpenAILLMClient
    import config
except ImportError as e:
    logger.warning(f"Could not import from llm.py: {e}")
    generate_answer = None
    LLMClient = None
    OpenAILLMClient = None
    # Fix: also null out `config`; previously it was left undefined, so later
    # references (e.g. in get_llm_instance) raised NameError instead of the
    # intended clean failure path.
    config = None
|
|
|
|
|
# FastAPI application instance; the title appears in the auto-generated docs UI.
app = FastAPI(title="Amazon Multimodal API")


# Lazily-initialized process-wide LLM client; created on first use by
# get_llm_instance() (and warmed eagerly at startup).
LLM_INSTANCE = None
|
|
|
|
|
def get_llm_instance():
    """Return the process-wide LLM client, creating it on first use.

    Prefers the OpenAI-backed client when ``config.USE_OPENAI`` is set and
    the class imported successfully; otherwise falls back to the local model
    client. Raises if no client is available or initialization fails.
    """
    global LLM_INSTANCE

    # Fast path: already initialized.
    if LLM_INSTANCE is not None:
        return LLM_INSTANCE

    try:
        want_openai = config.USE_OPENAI and OpenAILLMClient is not None
        if want_openai:
            logger.info(f"Initializing OpenAI {config.OPENAI_MODEL}...")
            LLM_INSTANCE = OpenAILLMClient(
                api_key=config.OPENAI_API_KEY,
                model=config.OPENAI_MODEL,
                max_tokens=config.OPENAI_MAX_TOKENS,
                temperature=config.OPENAI_TEMPERATURE
            )
            logger.info(f"OpenAI {config.OPENAI_MODEL} loaded successfully!")
        elif LLMClient is not None:
            logger.info(f"Initializing local model {config.LLM_MODEL} (this may take a few minutes)...")
            LLM_INSTANCE = LLMClient(model_name=config.LLM_MODEL)
            logger.info("Local LLM model loaded successfully!")
        else:
            raise ImportError("No LLM client available")
    except Exception as e:
        logger.error(f"Failed to load LLM model: {e}")
        raise

    return LLM_INSTANCE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cleaned product-catalog CSV; used both for stats and for index building.
CSV_PATH = "amazon_multimodal_clean.csv"

# In-memory statistics served by /api/info and refreshed by load_stats().
STATS = {
    "product_count": 0,
    "category_count": 0,
    "index_ready": False
}
|
|
|
|
|
def load_stats():
    """Refresh the module-level STATS dict from the catalog CSV and index dir.

    Errors are logged rather than raised so a missing/corrupt CSV does not
    prevent the server from starting.
    """
    global STATS

    # The index may appear/disappear independently of the CSV.
    STATS["index_ready"] = os.path.isdir("chromadb_store")

    # Guard clause: nothing more to do without the catalog file.
    if not os.path.exists(CSV_PATH):
        logger.warning(f"CSV file not found at: {CSV_PATH}")
        return

    try:
        frame = pd.read_csv(CSV_PATH)
        STATS["product_count"] = len(frame)
        has_category = "main_category" in frame.columns
        STATS["category_count"] = frame["main_category"].nunique() if has_category else 0
        logger.info(f"Loaded Stats: {STATS}")
    except Exception as e:
        logger.error(f"Error loading CSV: {e}")


# Populate stats at import time so /api/info is accurate immediately.
load_stats()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """On startup: build the ChromaDB index if missing, then warm up the LLM.

    Index building can take several minutes on the very first boot; failures
    are logged but do not prevent the server from starting.
    """
    from rag import build_index

    # ChromaDB persists its store as a sqlite file inside the persist dir;
    # its presence is the cheapest "index exists" check.
    db_file = os.path.join("chromadb_store", "chroma.sqlite3")
    if not os.path.exists(db_file):
        logger.info("=" * 60)
        logger.info("ChromaDB index not found. Building index...")
        logger.info("This may take 2-5 minutes on first startup.")
        logger.info("=" * 60)

        try:
            # Reuse the module-level CSV_PATH so the index and the stats
            # endpoint always read the same catalog file (previously this
            # path was duplicated as a hard-coded literal).
            build_index(
                csv_path=CSV_PATH,
                persist_dir="chromadb_store",
                max_items=None
            )
            logger.info("✅ Index built successfully!")
        except Exception as e:
            logger.error(f"❌ Failed to build index: {e}")
    else:
        logger.info("✅ ChromaDB index found. Ready to serve requests.")

    # Warm the LLM so the first user request doesn't pay the model-load cost.
    try:
        logger.info("Pre-initializing LLM instance...")
        get_llm_instance()
        logger.info("✅ LLM instance ready!")
    except Exception as e:
        logger.warning(f"⚠️ Failed to pre-initialize LLM: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Allow cross-origin requests so a separately hosted frontend can call the API.
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# rejected by browsers per the CORS spec (credentialed requests require an
# explicit origin echo) — confirm whether credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/info")
async def get_system_info():
    """Return live system statistics consumed by the frontend Header."""
    # Recompute readiness on every call: the index may be built after startup.
    index_present = os.path.isdir("chromadb_store")
    STATS["index_ready"] = index_present
    return STATS
|
|
|
|
|
|
|
|
@app.get("/health")
@app.head("/health")
async def health_check():
    """Health check endpoint for Docker and HF Spaces monitoring."""
    import os
    payload = {
        "status": "healthy",
        "index_ready": os.path.isdir("chromadb_store"),
        "llm_initialized": LLM_INSTANCE is not None,
    }
    return payload
|
|
|
|
|
|
|
|
@app.post("/api/search")
async def search(
    query: str = Form(""),
    mode: str = Form("multimodal"),
    history: str = Form("[]"),
    image: Optional[UploadFile] = File(None)
):
    """
    Main search endpoint supporting text, image, and multimodal queries.

    Form fields:
        query:   free-text question (may be empty for image-only search)
        mode:    retrieval mode passed through to generate_answer
        history: JSON-encoded chat-history list; malformed input falls back to []
        image:   optional uploaded product image

    Returns a dict with the generated answer, the retrieved products (with
    web-servable image URLs under /product_images/), the retrieval method,
    and a status flag. Raises HTTP 400 if the upload cannot be saved and
    HTTP 500 if the backend is unavailable or the search fails.
    """
    logger.info(f"Search request: mode={mode}, query_length={len(query)}, has_image={image is not None}")

    if not generate_answer:
        logger.error("Backend logic (llm.py) not loaded")
        raise HTTPException(status_code=500, detail="Service temporarily unavailable")

    # Persist the upload to a temp file so downstream code can use a filesystem path.
    temp_image_path = None
    if image:
        try:
            suffix = Path(image.filename).suffix or ".jpg"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                shutil.copyfileobj(image.file, tmp)
                temp_image_path = tmp.name
            logger.info(f"Saved uploaded image to: {temp_image_path}")
        except Exception as e:
            logger.error(f"Failed to save uploaded image: {e}")
            raise HTTPException(status_code=400, detail="Failed to process image upload")

    # Chat history arrives as a JSON string inside multipart form data.
    try:
        chat_history = json.loads(history)
    except Exception as e:
        logger.warning(f"Failed to parse chat history: {e}")
        chat_history = []

    try:
        llm_instance = get_llm_instance()
        result = generate_answer(
            user_question=query,
            image_path=temp_image_path,
            mode=mode,
            chat_history=chat_history,
            llm_client=llm_instance
        )
        logger.info(f"Search successful: returned {len(result.get('products', []))} products")

        # Rewrite local image paths into URLs served by the /product_images mount.
        processed_products = []
        for p in result.get("products", []):
            raw_path = p.get("image_path", "")
            filename = os.path.basename(raw_path)

            # BUG FIX: previously the f-string contained the literal text
            # "(unknown)" instead of interpolating {filename}, so every
            # product pointed at the non-existent URL "/product_images/(unknown)".
            web_url = f"/product_images/{filename}" if filename else ""

            processed_products.append({
                "name": p.get("name", "Unknown Product"),
                "category": p.get("category", "General"),
                # The retriever returns a distance; convert to a similarity score.
                "similarity": 1 - p.get("distance", 0.0),
                "image": web_url,
            })

        return {
            "answer": result.get("answer", "No answer generated."),
            "products": processed_products,
            "retrieval_method": result.get("retrieval_method", mode),
            "status": "success"
        }

    except Exception as e:
        logger.error(f"Search API error: {str(e)}", exc_info=True)
        # Don't leak internals to the client.
        raise HTTPException(status_code=500, detail="An error occurred processing your search")

    finally:
        # Always remove the temp image, on both success and failure paths.
        if temp_image_path and os.path.exists(temp_image_path):
            try:
                os.unlink(temp_image_path)
                logger.debug(f"Cleaned up temporary file: {temp_image_path}")
            except Exception as e:
                logger.warning(f"Failed to clean up temporary file {temp_image_path}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Serve product images as static files when the images/ directory exists;
# the /api/search response builds URLs under this /product_images mount.
if os.path.exists("images"):
    app.mount("/product_images", StaticFiles(directory="images"), name="images")
|
|
|
|
|
|
|
|
from fastapi.responses import FileResponse |
|
|
|
|
|
@app.get("/")
async def serve_index():
    """Serve the frontend entry point (index.html)."""
    index_page = "frontend/index.html"
    return FileResponse(index_page)
|
|
|
|
|
@app.get("/main.js")
async def serve_main_js():
    """Serve the frontend JavaScript bundle."""
    script_path = "frontend/main.js"
    return FileResponse(script_path)
|
|
|
|
|
@app.get("/amazon-logo.png")
async def serve_logo():
    """Serve the site logo image."""
    logo_path = "frontend/amazon-logo.png"
    return FileResponse(logo_path)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the API directly (outside a container / process manager) using
    # host/port settings from config.py.
    import config

    uvicorn.run(app, host=config.API_HOST, port=config.API_PORT)