File size: 9,562 Bytes
ab26b91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# api_server.py
import os
import shutil
import tempfile
import uvicorn
import json
import logging
import pandas as pd
from pathlib import Path
from typing import List, Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware

# ==============================================
# Logging Configuration
# ==============================================
# Configure the root logger once at import time; the module-level
# `logger` is used throughout for startup and request diagnostics.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# Import the LLM backends lazily-guarded: if llm.py (or its heavy
# dependencies) is unavailable, the server still starts and the /api/search
# endpoint reports the failure per-request instead of crashing at import.
try:
    from llm import generate_answer, LLMClient, OpenAILLMClient
    import config
except ImportError as e:
    logger.warning(f"Could not import from llm.py: {e}")
    # Sentinels checked later: generate_answer gates /api/search, the two
    # client classes gate which backend get_llm_instance() can build.
    generate_answer = None
    LLMClient = None
    OpenAILLMClient = None

app = FastAPI(title="Amazon Multimodal API")

# ==============================
# Global LLM Instance (Singleton)
# ==============================
# Process-wide LLM client, built lazily on first use.
LLM_INSTANCE = None

def get_llm_instance():
    """Return the singleton LLM client, constructing it on first call.

    Prefers the OpenAI-backed client when ``config.USE_OPENAI`` is set and
    the class imported successfully; otherwise falls back to the local
    HuggingFace client. Raises if neither backend is available or if
    construction fails.
    """
    global LLM_INSTANCE
    if LLM_INSTANCE is not None:
        return LLM_INSTANCE

    try:
        if config.USE_OPENAI and OpenAILLMClient is not None:
            # OpenAI-hosted backend.
            logger.info(f"Initializing OpenAI {config.OPENAI_MODEL}...")
            LLM_INSTANCE = OpenAILLMClient(
                api_key=config.OPENAI_API_KEY,
                model=config.OPENAI_MODEL,
                max_tokens=config.OPENAI_MAX_TOKENS,
                temperature=config.OPENAI_TEMPERATURE,
            )
            logger.info(f"OpenAI {config.OPENAI_MODEL} loaded successfully!")
        elif LLMClient is not None:
            # Local HuggingFace backend.
            logger.info(f"Initializing local model {config.LLM_MODEL} (this may take a few minutes)...")
            LLM_INSTANCE = LLMClient(model_name=config.LLM_MODEL)
            logger.info("Local LLM model loaded successfully!")
        else:
            raise ImportError("No LLM client available")
    except Exception as e:
        logger.error(f"Failed to load LLM model: {e}")
        raise

    return LLM_INSTANCE

# ==============================
# 0. Preload data (for Header statistics)
# ==============================
# Path of the product catalogue CSV used for the header statistics.
CSV_PATH = "amazon_multimodal_clean.csv"
# Mutable stats dict served verbatim by /api/info; filled by load_stats().
STATS = {
    "product_count": 0,   # number of rows in the CSV
    "category_count": 0,  # distinct values of the "main_category" column
    "index_ready": False  # whether the chromadb_store directory exists
}

def load_stats():
    """Populate STATS from the catalogue CSV and index directory.

    Mutates the module-level STATS dict in place; missing CSV or read
    errors are logged and leave the zero defaults untouched.
    """
    global STATS
    # Index readiness is inferred from the ChromaDB directory's presence.
    STATS["index_ready"] = os.path.isdir("chromadb_store")

    if not os.path.exists(CSV_PATH):
        logger.warning(f"CSV file not found at: {CSV_PATH}")
        return

    try:
        frame = pd.read_csv(CSV_PATH)
        STATS["product_count"] = len(frame)
        has_category = "main_category" in frame.columns
        STATS["category_count"] = frame["main_category"].nunique() if has_category else 0
        logger.info(f"Loaded Stats: {STATS}")
    except Exception as e:
        logger.error(f"Error loading CSV: {e}")

# Execute loading on startup
load_stats()

# ==============================
# 4. Startup Event: Build Index if Missing
# ==============================
@app.on_event("startup")
async def startup_event():
    """Build the ChromaDB index on first startup and pre-warm the LLM.

    NOTE(review): ``@app.on_event`` is deprecated in recent FastAPI in
    favor of lifespan handlers; kept as-is to avoid restructuring app
    creation — confirm target FastAPI version before migrating.
    """
    # Local import: rag pulls in heavy indexing dependencies we only need
    # when an index actually has to be built. (The previous function-local
    # `import os` was redundant — os is already imported at module level.)
    from rag import build_index

    # Check for the ChromaDB database file itself, not just the directory:
    # an empty or partially-created directory does not mean a usable index.
    db_file = os.path.join("chromadb_store", "chroma.sqlite3")
    if not os.path.exists(db_file):
        logger.info("=" * 60)
        logger.info("ChromaDB index not found. Building index...")
        logger.info("This may take 2-5 minutes on first startup.")
        logger.info("=" * 60)

        try:
            build_index(
                csv_path="amazon_multimodal_clean.csv",
                persist_dir="chromadb_store",
                max_items=None  # Use full dataset
            )
            logger.info("βœ… Index built successfully!")
        except Exception as e:
            # Index build failure is logged but not fatal: the server still
            # starts and /health reports index_ready=False.
            logger.error(f"❌ Failed to build index: {e}")
    else:
        logger.info("βœ… ChromaDB index found. Ready to serve requests.")

    # Pre-initialize the LLM so the first search request avoids a cold start.
    try:
        logger.info("Pre-initializing LLM instance...")
        get_llm_instance()
        logger.info("βœ… LLM instance ready!")
    except Exception as e:
        logger.warning(f"⚠️ Failed to pre-initialize LLM: {e}")

# ==============================
# 1. CORS Configuration
# ==============================
# Permissive CORS for development.
# NOTE(review): browsers reject credentialed responses whose
# Access-Control-Allow-Origin is a wildcard — combining allow_origins=["*"]
# with allow_credentials=True may not work for credentialed requests;
# confirm and restrict origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],          # Allow all origins in development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ==============================
# 2. API Endpoints (must be defined BEFORE mounting static files!)
# ==============================

@app.get("/api/info")
async def get_system_info():
    """Expose dataset/index statistics consumed by the frontend header."""
    # Refresh the readiness flag on every call: the index may have been
    # built after STATS was first populated at import time.
    STATS["index_ready"] = os.path.isdir("chromadb_store")
    return STATS


@app.get("/health")
@app.head("/health")
async def health_check():
    """Liveness probe for Docker and HF Spaces monitoring.

    Reports overall status, whether the vector-index directory exists on
    disk, and whether the LLM singleton has been initialized. (The previous
    function-local ``import os`` was redundant — os is imported at module
    level — and has been removed.)
    """
    return {
        "status": "healthy",
        "index_ready": os.path.isdir("chromadb_store"),
        "llm_initialized": LLM_INSTANCE is not None
    }


@app.post("/api/search")
async def search(
    query: str = Form(""),
    mode: str = Form("multimodal"),
    history: str = Form("[]"),
    image: Optional[UploadFile] = File(None)
):
    """
    Main search endpoint supporting text, image, and multimodal queries.

    Form fields:
        query:   free-text question (may be empty for image-only search)
        mode:    retrieval mode, forwarded to generate_answer
        history: JSON-encoded list of prior chat turns
        image:   optional uploaded image used for visual retrieval

    Returns a JSON payload with the generated answer, the retrieved
    products (name/category/similarity/image URL), the retrieval method
    used, and a status flag. Raises HTTP 500 when the LLM backend is
    unavailable or the pipeline fails, HTTP 400 on a bad image upload.
    """
    logger.info(f"Search request: mode={mode}, query_length={len(query)}, has_image={image is not None}")

    if not generate_answer:
        logger.error("Backend logic (llm.py) not loaded")
        raise HTTPException(status_code=500, detail="Service temporarily unavailable")

    temp_image_path = None
    if image:
        try:
            # Persist the upload to a temp file so downstream code can work
            # with a filesystem path.
            suffix = Path(image.filename).suffix or ".jpg"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                shutil.copyfileobj(image.file, tmp)
                temp_image_path = tmp.name
            logger.info(f"Saved uploaded image to: {temp_image_path}")
        except Exception as e:
            logger.error(f"Failed to save uploaded image: {e}")
            raise HTTPException(status_code=400, detail="Failed to process image upload")

    # Chat history arrives as a JSON string; fall back to an empty history
    # rather than failing the whole request on a malformed payload.
    try:
        chat_history = json.loads(history)
    except Exception as e:
        logger.warning(f"Failed to parse chat history: {e}")
        chat_history = []

    try:
        # Reuse the global LLM instance to avoid per-request model init.
        llm_instance = get_llm_instance()
        result = generate_answer(
            user_question=query,
            image_path=temp_image_path,
            mode=mode,
            chat_history=chat_history,
            llm_client=llm_instance
        )
        logger.info(f"Search successful: returned {len(result.get('products', []))} products")

        processed_products = []
        for p in result.get("products", []):
            raw_path = p.get("image_path", "")
            filename = os.path.basename(raw_path)
            # BUG FIX: the URL f-string previously contained a literal
            # placeholder instead of interpolating {filename}, so every
            # product linked to a nonexistent image. Build the URL under
            # the /product_images static mount (see mount below).
            web_url = f"/product_images/{filename}" if filename else ""

            processed_products.append({
                "name": p.get("name", "Unknown Product"),
                "category": p.get("category", "General"),
                # Retrieval returns a distance; convert to a similarity score.
                "similarity": 1 - p.get("distance", 0.0),
                "image": web_url,
            })

        return {
            "answer": result.get("answer", "No answer generated."),
            "products": processed_products,
            "retrieval_method": result.get("retrieval_method", mode),
            "status": "success"
        }

    except Exception as e:
        logger.error(f"Search API error: {str(e)}", exc_info=True)
        # Don't expose internal error details to the client.
        raise HTTPException(status_code=500, detail="An error occurred processing your search")

    finally:
        # Always remove the temporary uploaded image, even on failure.
        if temp_image_path and os.path.exists(temp_image_path):
            try:
                os.unlink(temp_image_path)
                logger.debug(f"Cleaned up temporary file: {temp_image_path}")
            except Exception as e:
                logger.warning(f"Failed to clean up temporary file {temp_image_path}: {e}")


# ==============================
# 3. Static File Mounting
# ==============================

# A. Product images directory: serve files from images/ at /product_images.
# Mounted only when the directory exists so startup doesn't fail without it.
if os.path.exists("images"):
    app.mount("/product_images", StaticFiles(directory="images"), name="images")

# B. Frontend static files - serve individual files to avoid blocking API routes
from fastapi.responses import FileResponse

@app.get("/")
async def serve_index():
    """Serve the frontend entry point (frontend/index.html)."""
    return FileResponse("frontend/index.html")

@app.get("/main.js")
async def serve_main_js():
    """Serve the frontend JavaScript bundle."""
    return FileResponse("frontend/main.js")

@app.get("/amazon-logo.png")
async def serve_logo():
    """Serve the site logo image."""
    return FileResponse("frontend/amazon-logo.png")


if __name__ == "__main__":
    # Host and port are read from the project config module when run directly.
    import config
    uvicorn.run(app, host=config.API_HOST, port=config.API_PORT)