Akash-Dragon commited on
Commit
be197b9
Β·
1 Parent(s): d7f6907

Fix absolute paths for HF Spaces deployment

Browse files
Files changed (1) hide show
  1. app.py +1169 -1168
app.py CHANGED
@@ -1,1168 +1,1169 @@
1
- # %%
2
- # ============================================================
3
- # JEWELLERY MULTIMODAL SEARCH BACKEND (FASTAPI)
4
- # ============================================================
5
-
6
- # %%
7
- # ============================================================
8
- # IMPORTS
9
- # ============================================================
10
-
11
- import os
12
- import json
13
- from typing import List, Dict
14
-
15
- import torch
16
- import clip
17
- import numpy as np
18
- import chromadb
19
-
20
- from fastapi import FastAPI, HTTPException, File, UploadFile
21
- from fastapi.middleware.cors import CORSMiddleware
22
- from fastapi.responses import FileResponse
23
- from pydantic import BaseModel
24
-
25
- from openai import OpenAI
26
- from dotenv import load_dotenv
27
- from sentence_transformers import CrossEncoder
28
-
29
- import base64
30
- import requests
31
- from PIL import Image
32
- import io
33
-
34
- # Load environment variables from .env file
35
- load_dotenv()
36
-
37
- # %%
38
- # ============================================================
39
- # CONFIG
40
- # ============================================================
41
-
42
- # Auto-detect if running in Docker (HF Spaces) or locally
43
# Auto-detect if running in Docker (HF Spaces) or locally
if os.path.exists("/app/data"):
    # Running in Docker (Hugging Face Spaces)
    BASE_DIR = "/app"
else:
    # Running locally - use script directory as base
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# All data paths hang off BASE_DIR so the app is relocatable between a
# local checkout and the /app image used on HF Spaces.
CHROMA_PATH = os.path.join(BASE_DIR, "chroma_primary")  # persisted Chroma DB directory
DATA_DIR = os.path.join(BASE_DIR, "data", "tanishq")  # dataset root
IMAGE_DIR = os.path.join(DATA_DIR, "images")  # product image files
BLIP_CAPTIONS_PATH = os.path.join(DATA_DIR, "blip_captions.json")  # image_id -> caption

# Prefer GPU when available; all torch/CLIP work runs on this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
56
-
57
- # %%
58
- # ============================================================
59
- # LAZY MODEL LOADING (Reduces cold start time)
60
- # ============================================================
61
-
62
- # Global model references (loaded on first use)
63
- clip_model = None
64
- cross_encoder = None
65
-
66
def get_clip_model():
    """Return the shared CLIP model, loading it lazily on first call.

    The loaded model is cached in the module-level ``clip_model`` global so
    the expensive checkpoint load happens at most once per process.
    """
    global clip_model
    if clip_model is not None:
        return clip_model

    print("πŸ”Ή Loading CLIP model...")
    loaded, _ = clip.load("ViT-B/16", device=DEVICE)
    loaded.eval()
    clip_model = loaded
    print("βœ… CLIP model loaded")
    return clip_model
76
-
77
def get_cross_encoder():
    """Return the shared cross-encoder, loading it lazily on first call.

    Cached in the module-level ``cross_encoder`` global; used only when
    re-ranking is requested, which keeps cold starts fast.
    """
    global cross_encoder
    if cross_encoder is not None:
        return cross_encoder

    print("πŸ”Ή Loading Cross-Encoder for re-ranking...")
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    print("βœ… Cross-Encoder loaded")
    return cross_encoder
85
-
86
- # %%
87
# Load BLIP-generated image captions (image_id -> caption string) once at
# startup; used as the text side of cross-encoder re-ranking documents.
print("πŸ”Ή Loading BLIP captions...")
with open(BLIP_CAPTIONS_PATH, "r") as f:
    BLIP_CAPTIONS = json.load(f)
90
-
91
- # %%
92
- # ============================================================
93
- # INITIALIZE GROQ LLM CLIENT
94
- # ============================================================
95
-
96
# LLM client (Groq, via its OpenAI-compatible endpoint). Optional: when the
# key is missing, groq_client is None and every LLM feature degrades to a
# deterministic non-LLM fallback path.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if GROQ_API_KEY:
    print("πŸ”Ή Initializing Groq LLM client...")
    groq_client = OpenAI(
        base_url="https://api.groq.com/openai/v1",
        api_key=GROQ_API_KEY
    )
else:
    groq_client = None
    print("⚠️ GROQ_API_KEY not set; LLM features disabled (fallbacks enabled)")

# NVIDIA OCR API configuration (primary OCR backend for uploaded images)
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY")
NVIDIA_OCR_URL = "https://ai.api.nvidia.com/v1/cv/nvidia/nemoretriever-ocr-v1"

# Fallback OCR configuration (OpenAI-compatible GPT-4.1-nano endpoint)
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
114
- # %%
115
- # ============================================================
116
- # LOAD CHROMA (PERSISTED DB)
117
- # ============================================================
118
-
119
# Open the persisted Chroma database and the two collections built offline:
#   - jewelry_images:   CLIP image embeddings (queried by vector similarity)
#   - jewelry_metadata: per-item attributes + extraction confidences
# get_collection() raises if a collection is missing, so a bad build of the
# DB fails fast at boot rather than at query time.
print("πŸ”Ή Connecting to Chroma DB...")
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)

image_collection = chroma_client.get_collection("jewelry_images")
metadata_collection = chroma_client.get_collection("jewelry_metadata")

print(
    "βœ… Chroma loaded | Images:",
    image_collection.count(),
    "| Metadata:",
    metadata_collection.count()
)
131
-
132
- # %%
133
- # ============================================================
134
- # FASTAPI APP
135
- # ============================================================
136
-
137
app = FastAPI(title="Jewellery Multimodal Search")

# CORS: Starlette's CORSMiddleware compares `allow_origins` entries as exact
# strings, so wildcard hosts such as "https://*.vercel.app" NEVER matched and
# those clients were silently blocked. Wildcard domains must be expressed via
# `allow_origin_regex`; fixed origins stay in `allow_origins`.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "https://tanishq-rag-capstone-1lj2x4y1v-akash-aimls-projects.vercel.app",  # Vercel preview
    ],
    # All Vercel deployments plus ngrok tunnels (free, new and legacy domains).
    allow_origin_regex=r"https://.*\.(vercel\.app|ngrok-free\.app|ngrok-free\.dev|ngrok\.io)",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
153
-
154
- # %%
155
- # ============================================================
156
- # MIDDLEWARE FOR HF SPACES OPTIMIZATION
157
- # ============================================================
158
-
159
- import asyncio
160
- from starlette.requests import Request
161
-
162
@app.middleware("http")
async def add_optimizations(request: Request, call_next):
    """Enforce a 5 MB upload cap and a 60 s per-request timeout.

    FIX: exceptions raised inside an HTTP middleware bypass FastAPI's
    exception handlers (those run beneath the middleware stack), so the
    original `raise HTTPException(...)` surfaced as a bare 500. Limit
    violations are now returned as JSON responses directly. A non-numeric
    Content-Length header is also ignored instead of crashing `int()`.
    """
    from fastapi.responses import JSONResponse

    # Reject oversized uploads early, before the body is read (max 5MB).
    if request.method == "POST":
        content_length = request.headers.get("content-length")
        if content_length and content_length.isdigit() and int(content_length) > 5 * 1024 * 1024:
            return JSONResponse(
                status_code=413,
                content={"detail": "File too large (max 5MB)"},
            )

    # Bound total handling time (60s for local dev/slower machines).
    try:
        return await asyncio.wait_for(call_next(request), timeout=60.0)
    except asyncio.TimeoutError:
        return JSONResponse(
            status_code=504,
            content={"detail": "Request timeout (max 60s)"},
        )
178
-
179
- # %%
180
- # ============================================================
181
- # REQUEST / RESPONSE SCHEMAS
182
- # ============================================================
183
-
184
class TextSearchRequest(BaseModel):
    """Request body for POST /search/text."""

    query: str  # Free-text query; may be empty when only filters are used
    filters: Dict[str, str] = None  # Explicit UI filters (e.g. {"metal": "gold"}); None = no filtering
    top_k: int = 5  # Maximum number of results to return
    use_reranking: bool = True  # Toggle cross-encoder (3x faster when False)
    use_explanations: bool = True  # Toggle LLM explanations (500ms+ faster when False)
190
-
191
-
192
class SimilarSearchRequest(BaseModel):
    """Request body for POST /search/similar."""

    image_id: str  # ID of a catalogue image already stored in Chroma
    top_k: int = 5  # Maximum number of similar items to return
195
-
196
- # %%
197
- # ============================================================
198
- # CLIP QUERY ENCODING (TEXT ONLY)
199
- # ============================================================
200
-
201
def encode_text_clip(text: str) -> np.ndarray:
    """Embed *text* with CLIP and return the L2-normalised vector.

    Intermediate tensors are freed (and the CUDA cache emptied when on GPU)
    to keep the memory footprint flat on small deployments.
    """
    model = get_clip_model()
    token_batch = clip.tokenize([text]).to(DEVICE)
    with torch.no_grad():
        features = model.encode_text(token_batch)
        features = features / features.norm(dim=-1, keepdim=True)
        vector = features.cpu().numpy()[0]

    # Release tensors eagerly.
    del token_batch, features
    if DEVICE == "cuda":
        torch.cuda.empty_cache()

    return vector
216
-
217
- # %%
218
- # ============================================================
219
- # INTENT & ATTRIBUTE DETECTION WITH LLM (STRUCTURED)
220
- # ============================================================
221
-
222
def detect_intent_and_attributes(query: str) -> Dict:
    """
    Extract search attributes and exclusions from query using LLM with fixed schema.

    Returns:
        {
            "intent": "search",
            "attributes": {category, metal, primary_stone},  # Items to INCLUDE
            "exclusions": {category, metal, primary_stone}  # Items to EXCLUDE
        }

    Falls back to a keyword heuristic when the Groq client is unavailable or
    the LLM call/JSON parse fails, so this never raises to the caller.
    """

    # Few-shot prompt with a fixed JSON schema; doubled braces are literal
    # braces inside this f-string.
    prompt = f"""Extract jewellery search attributes from this query.

Query: "{query}"

Return ONLY valid JSON with this exact schema:
{{
"intent": "search",
"attributes": {{
"category": "ring|necklace|earring|bracelet|null",
"metal": "gold|silver|platinum|null",
"primary_stone": "diamond|pearl|ruby|emerald|sapphire|null"
}},
"exclusions": {{
"category": "ring|necklace|earring|bracelet|null",
"metal": "gold|silver|platinum|null",
"primary_stone": "diamond|pearl|ruby|emerald|sapphire|null"
}}
}}

Rules:
- "attributes" = what to INCLUDE (positive filters)
- "exclusions" = what to EXCLUDE (negative filters)
- Use null for unspecified fields
- Detect negations: "no", "without", "not", "plain", "-free"

Examples:

Query: "gold ring with diamonds"
{{"intent": "search", "attributes": {{"category": "ring", "metal": "gold", "primary_stone": "diamond"}}, "exclusions": {{"category": null, "metal": null, "primary_stone": null}}}}

Query: "ring with no diamonds"
{{"intent": "search", "attributes": {{"category": "ring", "metal": null, "primary_stone": null}}, "exclusions": {{"category": null, "metal": null, "primary_stone": "diamond"}}}}

Query: "plain silver necklace"
{{"intent": "search", "attributes": {{"category": "necklace", "metal": "silver", "primary_stone": null}}, "exclusions": {{"category": null, "metal": null, "primary_stone": "any"}}}}

Query: "gold necklace without pearls"
{{"intent": "search", "attributes": {{"category": "necklace", "metal": "gold", "primary_stone": null}}, "exclusions": {{"category": null, "metal": null, "primary_stone": "pearl"}}}}

Return ONLY the JSON, no explanation."""

    def simple_fallback() -> Dict:
        # Keyword heuristic used when the LLM is unavailable or fails.
        # NOTE(review): covers only a subset of the schema's categories/metals/
        # stones and never produces exclusions.
        q = query.lower()
        attrs = {}

        if "necklace" in q:
            attrs["category"] = "necklace"
        elif "ring" in q:
            attrs["category"] = "ring"

        if "gold" in q:
            attrs["metal"] = "gold"
        elif "silver" in q:
            attrs["metal"] = "silver"

        if "pearl" in q:
            attrs["primary_stone"] = "pearl"
        elif "diamond" in q:
            attrs["primary_stone"] = "diamond"

        return {
            "intent": "search",
            "attributes": attrs,
            "exclusions": {}
        }

    if groq_client is None:
        return simple_fallback()

    try:
        response = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=300
        )

        result_text = response.choices[0].message.content.strip()

        # Extract JSON from response (handle markdown code blocks)
        if "```json" in result_text:
            result_text = result_text.split("```json")[1].split("```")[0].strip()
        elif "```" in result_text:
            result_text = result_text.split("```")[1].split("```")[0].strip()

        result = json.loads(result_text)

        # Clean null values (the model returns the string "null" for
        # unspecified fields; drop those and empty values entirely)
        result["attributes"] = {k: v for k, v in result.get("attributes", {}).items() if v and v != "null"}
        result["exclusions"] = {k: v for k, v in result.get("exclusions", {}).items() if v and v != "null"}

        return result

    except Exception as e:
        print(f"⚠️ LLM extraction failed: {e}, falling back to simple extraction")
        return simple_fallback()
330
-
331
-
332
- # %%
333
- # ============================================================
334
- # VISUAL RETRIEVAL (NO LANGCHAIN)
335
- # ============================================================
336
-
337
def retrieve_visual_candidates(query_text: str, k: int = 100, where_filter: Dict = None):
    """CLIP-embed *query_text* and fetch the k nearest images from Chroma.

    Returns a list of {"image_id", "visual_score"} dicts, where the score is
    the raw Chroma distance (lower = more similar). `where_filter` is passed
    straight through to Chroma's built-in metadata filtering.
    """
    embedding = encode_text_clip(query_text)

    hits = image_collection.query(
        query_embeddings=[embedding],
        n_results=k,
        where=where_filter
    )

    ids = hits["ids"]
    if not ids or not ids[0]:
        return []

    candidates = []
    for hit_id, distance in zip(ids[0], hits["distances"][0]):
        candidates.append({"image_id": hit_id, "visual_score": distance})
    return candidates
357
-
358
- # %%
359
- # ============================================================
360
- # METADATA SCORING REFINEMENTS
361
- # ============================================================
362
-
363
def adaptive_alpha(query_attrs: Dict) -> float:
    """Weight for the metadata boost: 0.1 base plus 0.1 per query attribute."""
    attr_count = len(query_attrs)
    return 0.1 + 0.1 * attr_count
365
-
366
-
367
def refined_metadata_adjustment(meta: Dict, query_attrs: Dict) -> float:
    """Confidence-weighted agreement between item metadata and query attributes.

    Each matching attribute adds its extraction confidence
    (``confidence_<attr>``); a confident (> 0.6) mismatch subtracts 30% of
    that confidence. Missing or low-confidence mismatches contribute nothing.
    """
    total = 0.0

    for attr, wanted in query_attrs.items():
        actual = meta.get(attr)
        confidence = meta.get(f"confidence_{attr}", 0.0)

        if actual == wanted:
            total += confidence
        elif confidence > 0.6:
            total -= 0.3 * confidence

    return total
380
-
381
-
382
def apply_metadata_boost(candidates: List[Dict], query_attrs: Dict, exclusions: Dict = None):
    """
    Rank candidates by combining visual similarity with metadata matching.
    HARD FILTER out excluded items completely.

    Args:
        candidates: List of {image_id, visual_score}
        query_attrs: Attributes to INCLUDE (boost matching items)
        exclusions: Attributes to EXCLUDE (FILTER OUT completely)

    Returns:
        New list of {image_id, visual_score, metadata_boost, final_score}
        sorted ascending by final_score (distance-based: lower = better).
    """
    if exclusions is None:
        exclusions = {}

    # Boost weight grows with the number of query attributes.
    alpha = adaptive_alpha(query_attrs)
    ranked = []

    for c in candidates:
        # One metadata lookup per candidate (N round-trips to Chroma).
        meta = metadata_collection.get(
            ids=[c["image_id"]],
            include=["metadatas"]
        )["metadatas"][0]

        # HARD FILTER: Skip items that match exclusions
        should_exclude = False
        for attr, excluded_value in exclusions.items():
            meta_value = meta.get(attr)

            # Handle "any" exclusion (e.g., "plain" means no stones at all)
            if excluded_value == "any":
                # Exclude if has ANY stone (not unknown/null)
                if meta_value and meta_value not in ["unknown", "null", ""]:
                    should_exclude = True
                    print(f"🚫 Excluding {c['image_id']}: has {attr}={meta_value} (want none)")
                    break
            # Handle specific exclusion
            elif meta_value == excluded_value:
                should_exclude = True
                print(f"🚫 Excluding {c['image_id']}: has {attr}={meta_value} (excluded)")
                break

        # Skip this item if it matches any exclusion
        if should_exclude:
            continue

        # Calculate positive boost from matching attributes
        adjust = refined_metadata_adjustment(meta, query_attrs)

        # Final score: visual + metadata boost (no exclusion penalty needed).
        # Subtracting because visual_score is a distance (lower = better).
        final_score = c["visual_score"] - alpha * adjust

        ranked.append({
            "image_id": c["image_id"],
            "visual_score": c["visual_score"],
            "metadata_boost": adjust,
            "final_score": final_score
        })

    return sorted(ranked, key=lambda x: x["final_score"])
440
-
441
-
442
def rerank_with_cross_encoder(
    query: str,
    candidates: List[Dict],
    top_k: int = 12
) -> List[Dict]:
    """
    Re-rank candidates using cross-encoder for better semantic matching.

    Two-stage pipeline:
    1. CLIP bi-encoder: Fast retrieval (already done)
    2. Cross-encoder: Accurate semantic re-ranking

    Args:
        query: User query text
        candidates: List of {image_id, visual_score, metadata_boost, ...}
        top_k: Number of results to return

    Returns:
        Re-ranked list of top K candidates

    NOTE: mutates the input dicts in place (adds `cross_encoder_score` and
    `final_score_reranked` keys to each candidate).
    """
    if not candidates:
        return []

    # Prepare query-document pairs for cross-encoder
    pairs = []
    for c in candidates:
        # Get BLIP caption for this image (empty string when missing)
        caption = BLIP_CAPTIONS.get(c["image_id"], "")

        # Get metadata (one Chroma round-trip per candidate)
        meta = metadata_collection.get(
            ids=[c["image_id"]],
            include=["metadatas"]
        )["metadatas"][0]

        # Create rich text representation combining caption + metadata
        doc_text = f"{caption}. Category: {meta.get('category', 'unknown')}, Metal: {meta.get('metal', 'unknown')}, Stone: {meta.get('primary_stone', 'unknown')}"

        pairs.append([query, doc_text])

    # Score all pairs with cross-encoder (batch processing)
    print(f"πŸ”„ Cross-encoder scoring {len(pairs)} candidates...")
    encoder = get_cross_encoder()
    cross_scores = encoder.predict(pairs, batch_size=32)

    # Combine scores: visual + metadata + cross-encoder
    for i, c in enumerate(candidates):
        c["cross_encoder_score"] = float(cross_scores[i])

        # Final score combines all signals
        # - Visual similarity (CLIP): 30% weight
        # - Metadata match: 20% weight
        # - Semantic similarity (cross-encoder): 50% weight (highest)
        c["final_score_reranked"] = (
            -c["visual_score"] * 0.3 +  # Negate because lower distance = better
            c.get("metadata_boost", 0) * 0.2 +
            c["cross_encoder_score"] * 0.5
        )

    # Sort by final score (higher is better)
    ranked = sorted(candidates, key=lambda x: x["final_score_reranked"], reverse=True)

    print(f"βœ… Re-ranked {len(ranked)} candidates, returning top {top_k}")
    return ranked[:top_k]
506
-
507
-
508
- # %%
509
- # ============================================================
510
- # IMAGE UPLOAD & OCR HELPER FUNCTIONS
511
- # ============================================================
512
-
513
def encode_uploaded_image(image_bytes: bytes) -> np.ndarray:
    """Encode an uploaded image into a normalised CLIP embedding.

    Args:
        image_bytes: Raw bytes of the uploaded file (any PIL-readable format).

    Returns:
        L2-normalised CLIP image embedding as a 1-D numpy array.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded or processed.
    """
    try:
        # Open image from bytes
        image = Image.open(io.BytesIO(image_bytes))

        # Convert to RGB if necessary (CLIP expects 3 channels)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Resize if too large (max 512x512 for efficiency)
        max_size = 512
        if max(image.size) > max_size:
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        # Preprocess for CLIP (standard ViT pipeline + CLIP normalisation stats)
        from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

        # FIX: use Image.Resampling.BICUBIC — the bare Image.BICUBIC constant
        # is deprecated since Pillow 9.1, and this file already uses the
        # Image.Resampling namespace (LANCZOS above), so this is consistent.
        preprocess = Compose([
            Resize(224, interpolation=Image.Resampling.BICUBIC),
            CenterCrop(224),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711))
        ])

        image_tensor = preprocess(image).unsqueeze(0).to(DEVICE)

        # Encode with CLIP and L2-normalise
        model = get_clip_model()
        with torch.no_grad():
            image_features = model.encode_image(image_tensor)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            result = image_features.cpu().numpy()[0]

        # Memory cleanup to keep the footprint small on shared hardware
        del image_tensor, image_features
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

        return result

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to process image: {str(e)}")
557
-
558
-
559
def extract_text_from_image(image_bytes: bytes) -> str:
    """Extract text from image using NVIDIA NeMo Retriever OCR API with GPT-4.1-Nano fallback

    Args:
        image_bytes: Raw bytes of the uploaded image.

    Returns:
        The extracted (stripped) text.

    Raises:
        HTTPException: 500 when both backends fail or no fallback key is set,
        400 when the fallback model finds no readable text.
    """

    # Try NVIDIA OCR first if key is configured
    extracted_text = ""
    # NOTE(review): nvidia_failed is written but never read afterwards.
    nvidia_failed = False

    if NVIDIA_API_KEY:
        try:
            # Encode image to base64
            image_b64 = base64.b64encode(image_bytes).decode()

            # Prepare request
            headers = {
                "Authorization": f"Bearer {NVIDIA_API_KEY}",
                "Accept": "application/json",
                "Content-Type": "application/json"
            }

            payload = {
                "input": [
                    {
                        "type": "image_url",
                        "url": f"data:image/png;base64,{image_b64}"
                    }
                ]
            }

            # Call NVIDIA OCR API
            print(f"πŸ“ž Calling NVIDIA OCR API...")
            response = requests.post(
                NVIDIA_OCR_URL,
                headers=headers,
                json=payload,
                timeout=15  # Shorter timeout to fail fast
            )

            if response.status_code == 200:
                result = response.json()

                # The API's response shape varies; try three known formats.
                # Format 1: Text detections array
                if "data" in result and isinstance(result["data"], list) and len(result["data"]) > 0:
                    for data_item in result["data"]:
                        if isinstance(data_item, dict) and "text_detections" in data_item:
                            for detection in data_item["text_detections"]:
                                if "text_prediction" in detection and "text" in detection["text_prediction"]:
                                    extracted_text += detection["text_prediction"]["text"] + " "
                        elif isinstance(data_item, dict) and "content" in data_item:
                            extracted_text += data_item["content"] + " "

                # Format 2: Direct text field
                elif "text" in result:
                    extracted_text = result["text"]

                # Format 3: Choices/Results
                elif "choices" in result and len(result["choices"]) > 0:
                    if "text" in result["choices"][0]:
                        extracted_text = result["choices"][0]["text"]
                    elif "message" in result["choices"][0]:
                        extracted_text = result["choices"][0]["message"].get("content", "")

                extracted_text = extracted_text.strip()
                if extracted_text:
                    print(f"βœ… Extracted text (NVIDIA): '{extracted_text}'")
                    return extracted_text

            # Reached on non-200 status OR on a 200 with no usable text.
            print(f"⚠️ NVIDIA OCR failed with status {response.status_code}. Trying fallback...")
            nvidia_failed = True

        except Exception as e:
            print(f"⚠️ NVIDIA OCR exception: {e}. Trying fallback...")
            nvidia_failed = True
    else:
        print("ℹ️ NVIDIA_API_KEY not set. Using fallback directly.")
        nvidia_failed = True

    # FALLBACK: Custom GPT-4.1-Nano OCR (OpenAI Compatible)
    try:
        if not OPENAI_API_KEY:
            raise HTTPException(status_code=500, detail="OCR unavailable: Primary failed and OPENAI_API_KEY not set for fallback.")

        print("πŸ”„ Using Global GPT-4.1-Nano Fallback...")

        # Initialize OpenAI client with custom base URL
        from openai import OpenAI

        client = OpenAI(
            base_url="https://apidev.navigatelabsai.com/v1",
            api_key=OPENAI_API_KEY
        )

        image_b64 = base64.b64encode(image_bytes).decode()

        response = client.chat.completions.create(
            model="gpt-4.1-nano",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Transcribe the handwritten text in this image exactly as it appears. Output ONLY the text, nothing else."},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{image_b64}"
                            }
                        }
                    ]
                }
            ],
            max_tokens=300
        )

        extracted_text = response.choices[0].message.content.strip()

        if not extracted_text:
            raise HTTPException(status_code=400, detail="No readable text found in image (Fallback).")

        print(f"βœ… Extracted text (Fallback GPT): '{extracted_text}'")
        return extracted_text

    except HTTPException:
        # Re-raise our own HTTP errors untouched (don't wrap them as 500s).
        raise
    except Exception as e:
        print(f"❌ Fallback OCR failed: {e}")
        raise HTTPException(status_code=500, detail=f"OCR failed permanently: {str(e)}")
684
-
685
-
686
- # %%
687
- # ============================================================
688
- # LLM-POWERED EXPLANATION (GROQ LLAMA 3.1) - BATCH PROCESSING
689
- # ============================================================
690
-
691
def batch_generate_explanations(results: List[Dict], query_attrs: Dict, user_query: str) -> List[str]:
    """Generate diverse, LLM-powered explanations for all search results in ONE API call

    Returns one explanation string per entry of *results*, in order. Any
    explanations the LLM fails to produce (or all of them, when the Groq
    client is unavailable) are filled in with template-based fallbacks.
    """

    if not results:
        return []

    # Build context for all items (handle up to 20 items in one call)
    items_context = []
    for idx, r in enumerate(results, 1):
        meta = metadata_collection.get(
            ids=[r["image_id"]],
            include=["metadatas"]
        )["metadatas"][0]

        matched_attrs = [v for k, v in query_attrs.items() if meta.get(k) == v]

        # Compact format to save tokens
        item_info = f"{idx}. {meta.get('category', 'item')} | {meta.get('metal', '?')} | {meta.get('primary_stone', '?')} | score:{r['visual_score']:.2f} | matched:{','.join(matched_attrs) if matched_attrs else 'none'}"
        items_context.append(item_info)

    # Compact prompt to fit more items (chr(10) = newline inside the f-string)
    prompt = f"""Query: "{user_query}"

Write 1 brief sentence for EACH item:

{chr(10).join(items_context)}

Format:
1. [sentence]
2. [sentence]
etc."""

    if groq_client is None:
        explanations = []
    else:
        try:
            # Single API call for ALL items
            response = groq_client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[
                    {"role": "system", "content": "Write brief jewellery recommendations."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.7,
                max_tokens=min(800, len(results) * 60),  # Increased for 12+ items
                top_p=0.9
            )

            # Parse response into one explanation per numbered line.
            full_response = response.choices[0].message.content.strip()
            explanations = []

            import re
            # "N. text" / "N: text" / "N) text" / "N- text", multi-line per item
            pattern = r'^\s*(\d+)[\.:)\-]\s*(.+?)(?=^\s*\d+[\.:)\-]|\Z)'
            matches = re.findall(pattern, full_response, re.MULTILINE | re.DOTALL)

            if matches and len(matches) >= len(results):
                for num, text in matches[:len(results)]:
                    clean_text = ' '.join(text.strip().split())
                    # Discard degenerate fragments (<= 10 chars)
                    if clean_text and len(clean_text) > 10:
                        explanations.append(clean_text)

            if len(explanations) >= len(results):
                return explanations[:len(results)]

            # If incomplete, pad with fallback
            print(f"⚠️ LLM returned {len(explanations)}/{len(results)} explanations, padding with fallback")

        except Exception as e:
            print(f"⚠️ LLM explanation failed: {e}, using fallback")
            explanations = []

    # Fallback for missing explanations: template text based on metadata
    # matches and visual distance (< 1.3 treated as a strong match).
    while len(explanations) < len(results):
        idx = len(explanations)
        r = results[idx]
        meta = metadata_collection.get(
            ids=[r["image_id"]],
            include=["metadatas"]
        )["metadatas"][0]
        matched_attrs = [v for k, v in query_attrs.items() if meta.get(k) == v]

        category = meta.get('category', 'item')
        metal = meta.get('metal', 'unknown')
        stone = meta.get('primary_stone', 'unknown')

        if matched_attrs and r['visual_score'] < 1.3:
            explanations.append(
                f"Excellent {category} featuring {' and '.join(matched_attrs)}. High visual similarity (score: {r['visual_score']:.2f})."
            )
        elif matched_attrs:
            explanations.append(
                f"Beautiful {metal} {category} with {stone}. Features {' and '.join(matched_attrs)}."
            )
        elif r['visual_score'] < 1.3:
            explanations.append(
                f"Highly similar {category} with excellent visual match. {metal.capitalize()} with {stone}."
            )
        else:
            explanations.append(
                f"Recommended {metal} {category} with {stone}. Good visual similarity."
            )

    return explanations
795
- # ============================================================
796
- # API ENDPOINTS
797
- # ============================================================
798
-
799
@app.get("/health")
def health_check():
    """Liveness probe used by HF Spaces monitoring.

    Reports which lazily-loaded models are currently resident and the row
    counts of both Chroma collections.
    """
    model_status = {
        "clip": clip_model is not None,
        "cross_encoder": cross_encoder is not None,
        "blip_captions": len(BLIP_CAPTIONS) > 0
    }
    db_status = {
        "images": image_collection.count(),
        "metadata": metadata_collection.count()
    }
    return {
        "status": "healthy",
        "models_loaded": model_status,
        "database": db_status
    }
814
-
815
@app.post("/search/text")
def search_text(req: TextSearchRequest):
    """Text-search endpoint: LLM intent extraction + CLIP retrieval with
    optional metadata filtering, cross-encoder re-ranking and LLM explanations.

    FIX: the original function ended with a second, unreachable `return`
    block duplicating the first one; that dead code has been removed.
    """
    # Detect intent/attributes from the free-text query (LLM with fallback).
    if req.query.strip():
        intent = detect_intent_and_attributes(req.query)
        attrs = intent["attributes"]
        # NOTE(review): intent["exclusions"] is extracted but not used here.
    else:
        intent = {"intent": "filter", "attributes": {}, "exclusions": {}}
        attrs = {}

    # === DUAL-STAGE FILTERING STRATEGY ===
    # 1. Identify valid IDs from Metadata Collection (Source of Truth)
    # 2. Use those IDs to filter Vector Search results

    # Construct WHERE clause for Metadata Collection from explicit UI filters
    where_clauses = []
    if req.filters:
        for key, value in req.filters.items():
            where_clauses.append({key: value.lower()})  # Explicit filters

    final_where = None
    if len(where_clauses) > 1:
        final_where = {"$and": where_clauses}
    elif len(where_clauses) == 1:
        final_where = where_clauses[0]

    valid_ids = None
    if final_where:
        # Fetch ALL valid IDs matching the filter
        print(f"πŸ” Filtering metadata with: {final_where}")
        meta_res = metadata_collection.get(where=final_where, include=["metadatas"])
        if meta_res["ids"]:
            valid_ids = set(meta_res["ids"])
            print(f"βœ… Found {len(valid_ids)} valid items matching filters.")
        else:
            print("⚠️ No items match the filters.")
            return {"query": req.query, "intent": attrs, "results": []}

    # === EXECUTE SEARCH ===

    # Case A: Filter Only (No Text Query) — return filtered items directly
    if not req.query.strip() and valid_ids:
        candidates = [{"image_id": vid, "visual_score": 0.0} for vid in list(valid_ids)[:req.top_k]]
        ranked = candidates  # No ranking needed without text
        explanations = ["Filtered result"] * len(ranked)

    # Case B: Text Query (with or without Filter)
    else:
        search_query = req.query if req.query.strip() else "jewellery"

        # Broader vector search, then intersect with valid_ids in Python:
        # image_collection lacks metadata, so Chroma-side `where` is unusable.
        fetch_k = 200 if valid_ids else 40
        candidates = retrieve_visual_candidates(search_query, k=fetch_k)

        filtered = [
            c for c in candidates
            if valid_ids is None or c["image_id"] in valid_ids
        ]

        # Cross-encoder re-ranking (optional; markedly slower when enabled)
        if req.use_reranking and filtered and req.query.strip():
            ranked = rerank_with_cross_encoder(req.query, filtered, req.top_k)
        else:
            ranked = filtered[:req.top_k]

        # Explanations (optional; one batched LLM call)
        if req.use_explanations and req.query.strip():
            explanations = batch_generate_explanations(ranked, attrs, search_query)
        else:
            explanations = ["Match found"] * len(ranked)

    # === FORMAT RESULTS ===
    # Fetch metadata for the final result set in one batched Chroma call
    if ranked:
        ranked_ids = [r["image_id"] for r in ranked]
        metas = metadata_collection.get(ids=ranked_ids, include=["metadatas"])["metadatas"]
        meta_map = {rid: m for rid, m in zip(ranked_ids, metas)}
    else:
        meta_map = {}

    results = []
    for r, explanation in zip(ranked, explanations):
        results.append({
            "image_id": r["image_id"],
            "explanation": explanation,
            "metadata": meta_map.get(r["image_id"], {}),
            "scores": {
                "visual": r["visual_score"],
                "final": r.get("visual_score", 0)  # simplified
            }
        })

    return {
        "query": req.query,
        "intent": attrs,
        "results": results
    }
935
-
936
- # %%
937
- @app.post("/search/similar")
938
- def search_similar(req: SimilarSearchRequest):
939
- base = image_collection.get(
940
- ids=[req.image_id],
941
- include=["embeddings"]
942
- )["embeddings"][0]
943
-
944
- res = image_collection.query(
945
- query_embeddings=[base],
946
- n_results=req.top_k + 1
947
- )
948
-
949
- base_meta = metadata_collection.get(
950
- ids=[req.image_id],
951
- include=["metadatas"]
952
- )["metadatas"][0]
953
-
954
- attrs = {
955
- k: base_meta[k]
956
- for k in ["category", "metal", "primary_stone"]
957
- if base_meta.get(k) != "unknown"
958
- }
959
-
960
- candidates = [
961
- {
962
- "image_id": img_id,
963
- "visual_score": dist
964
- }
965
- for img_id, dist in zip(res["ids"][0], res["distances"][0])
966
- if img_id != req.image_id
967
- ]
968
-
969
- ranked = apply_metadata_boost(candidates, attrs, {})[:req.top_k]
970
-
971
- # Generate all explanations in one batch LLM call
972
- # For similar search, use the base image ID as the query context
973
- query_context = f"items similar to {req.image_id}"
974
- explanations = batch_generate_explanations(ranked, attrs, query_context)
975
-
976
- results = []
977
- for r, explanation in zip(ranked, explanations):
978
- results.append({
979
- "image_id": r["image_id"],
980
- "explanation": explanation,
981
- "scores": {
982
- "visual": r["visual_score"],
983
- "metadata": r["metadata_boost"],
984
- "final": r["final_score"]
985
- }
986
- })
987
-
988
- return {
989
- "base_image": req.image_id,
990
- "results": results
991
- }
992
-
993
- # %%
994
- # ============================================================
995
- # IMAGE UPLOAD SEARCH ENDPOINT
996
- # ============================================================
997
-
998
- @app.post("/search/upload-image")
999
- async def search_by_uploaded_image(
1000
- file: UploadFile = File(...),
1001
- top_k: int = 12
1002
- ):
1003
- """
1004
- Search for similar jewellery items by uploading an image.
1005
- The image is encoded using CLIP and queried against the database.
1006
- """
1007
- # Validate file type
1008
- if not file.content_type or not file.content_type.startswith('image/'):
1009
- raise HTTPException(status_code=400, detail="File must be an image")
1010
-
1011
- try:
1012
- # Read image bytes
1013
- image_bytes = await file.read()
1014
-
1015
- # Encode image with CLIP
1016
- query_embedding = encode_uploaded_image(image_bytes)
1017
-
1018
- # Query ChromaDB
1019
- res = image_collection.query(
1020
- query_embeddings=[query_embedding.tolist()],
1021
- n_results=min(100, top_k * 10),
1022
- include=["distances"]
1023
- )
1024
-
1025
- # Get metadata for all results
1026
- candidates = []
1027
- for img_id, dist in zip(res["ids"][0], res["distances"][0]):
1028
- candidates.append({
1029
- "image_id": img_id,
1030
- "visual_score": dist
1031
- })
1032
-
1033
- # Get metadata from first result to infer attributes
1034
- if candidates:
1035
- base_meta = metadata_collection.get(
1036
- ids=[candidates[0]["image_id"]],
1037
- include=["metadatas"]
1038
- )["metadatas"][0]
1039
-
1040
- attrs = {
1041
- k: base_meta[k]
1042
- for k in ["category", "metal", "primary_stone"]
1043
- if base_meta.get(k) != "unknown"
1044
- }
1045
- else:
1046
- attrs = {}
1047
-
1048
- # Apply metadata boost (no exclusions for image upload)
1049
- ranked = apply_metadata_boost(candidates, attrs, {})[:top_k]
1050
-
1051
- # Generate explanations in batch
1052
- query_context = f"items visually similar to uploaded image"
1053
- explanations = batch_generate_explanations(ranked, attrs, query_context)
1054
-
1055
- results = []
1056
- for r, explanation in zip(ranked, explanations):
1057
- results.append({
1058
- "image_id": r["image_id"],
1059
- "explanation": explanation,
1060
- "scores": {
1061
- "visual": r["visual_score"],
1062
- "metadata": r["metadata_boost"],
1063
- "final": r["final_score"]
1064
- }
1065
- })
1066
-
1067
- return {
1068
- "query_type": "uploaded_image",
1069
- "filename": file.filename,
1070
- "results": results
1071
- }
1072
-
1073
- except HTTPException:
1074
- raise
1075
- except Exception as e:
1076
- raise HTTPException(status_code=500, detail=f"Image search failed: {str(e)}")
1077
-
1078
-
1079
- # %%
1080
- # ============================================================
1081
- # OCR QUERY SEARCH ENDPOINT
1082
- # ============================================================
1083
-
1084
- @app.post("/search/ocr-query")
1085
- async def search_by_ocr_query(
1086
- file: UploadFile = File(...),
1087
- top_k: int = 12
1088
- ):
1089
- """
1090
- Extract text from uploaded image using NVIDIA NeMo OCR,
1091
- then perform text-based search with the extracted query.
1092
- """
1093
- # Validate file type
1094
- if not file.content_type or not file.content_type.startswith('image/'):
1095
- raise HTTPException(status_code=400, detail="File must be an image")
1096
-
1097
- try:
1098
- # Read image bytes
1099
- image_bytes = await file.read()
1100
-
1101
- # Extract text using NVIDIA OCR
1102
- extracted_text = extract_text_from_image(image_bytes)
1103
-
1104
- print(f"πŸ“ Extracted text from image: '{extracted_text}'")
1105
-
1106
- # Use the extracted text for normal text search
1107
- intent = detect_intent_and_attributes(extracted_text)
1108
- attrs = intent["attributes"]
1109
- exclusions = intent.get("exclusions", {})
1110
-
1111
- # Stage 1: CLIP retrieval (reduced to k=40 for HF Spaces)
1112
- candidates = retrieve_visual_candidates(extracted_text, k=40)
1113
-
1114
- # Stage 2: Metadata boost + exclusion filtering
1115
- filtered = apply_metadata_boost(candidates, attrs, exclusions)
1116
-
1117
- # Stage 3: Cross-encoder re-ranking
1118
- ranked = rerank_with_cross_encoder(extracted_text, filtered, top_k)
1119
-
1120
- # Generate explanations in batch
1121
- explanations = batch_generate_explanations(ranked, attrs, extracted_text)
1122
-
1123
- results = []
1124
- for r, explanation in zip(ranked, explanations):
1125
- results.append({
1126
- "image_id": r["image_id"],
1127
- "explanation": explanation,
1128
- "scores": {
1129
- "visual": r["visual_score"],
1130
- "metadata": r["metadata_boost"],
1131
- "final": r["final_score"]
1132
- }
1133
- })
1134
-
1135
- return {
1136
- "query_type": "ocr_extracted",
1137
- "extracted_text": extracted_text,
1138
- "intent": attrs,
1139
- "results": results
1140
- }
1141
-
1142
- except HTTPException:
1143
- raise
1144
- except Exception as e:
1145
- raise HTTPException(status_code=500, detail=f"OCR search failed: {str(e)}")
1146
-
1147
- # %%
1148
- @app.get("/image/{image_id}")
1149
- def get_image(image_id: str):
1150
- path = os.path.join(IMAGE_DIR, image_id)
1151
- if not os.path.exists(path):
1152
- raise HTTPException(status_code=404, detail="Image not found")
1153
- return FileResponse(path)
1154
-
1155
- # %%
1156
- # ============================================================
1157
- # RUN SERVER
1158
- # ============================================================
1159
-
1160
- if __name__ == "__main__":
1161
- import uvicorn
1162
- print("πŸš€ Starting Jewellery Search API server...")
1163
- print(f"πŸ“ Data directory: {DATA_DIR}")
1164
- print(f"πŸ“ Image directory: {IMAGE_DIR}")
1165
- print(f"πŸ“ ChromaDB path: {CHROMA_PATH}")
1166
- print(f"🌐 Server will run on: http://localhost:8000")
1167
- print(f"πŸ“– API docs available at: http://localhost:8000/docs")
1168
- uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
 
 
1
+ # %%
2
+ # ============================================================
3
+ # JEWELLERY MULTIMODAL SEARCH BACKEND (FASTAPI)
4
+ # ============================================================
5
+
6
+ # %%
7
+ # ============================================================
8
+ # IMPORTS
9
+ # ============================================================
10
+
11
+ import os
12
+ import json
13
+ from typing import List, Dict
14
+
15
+ import torch
16
+ import clip
17
+ import numpy as np
18
+ import chromadb
19
+
20
+ from fastapi import FastAPI, HTTPException, File, UploadFile
21
+ from fastapi.middleware.cors import CORSMiddleware
22
+ from fastapi.responses import FileResponse
23
+ from pydantic import BaseModel
24
+
25
+ from openai import OpenAI
26
+ from dotenv import load_dotenv
27
+ from sentence_transformers import CrossEncoder
28
+
29
+ import base64
30
+ import requests
31
+ from PIL import Image
32
+ import io
33
+
34
+ # Load environment variables from .env file
35
+ load_dotenv()
36
+
37
+ # %%
38
+ # ============================================================
39
+ # CONFIG
40
+ # ============================================================
41
+
42
+ # Use absolute paths for deployment
43
+ # Priority: /app (Docker/HF Spaces) > script directory (local)
44
+ if os.path.exists("/app"):
45
+ # Running in Docker (Hugging Face Spaces)
46
+ BASE_DIR = "/app"
47
+ else:
48
+ # Running locally - use absolute path to script directory
49
+ BASE_DIR = os.path.abspath(os.path.dirname(__file__))
50
+
51
+ CHROMA_PATH = os.path.join(BASE_DIR, "chroma_primary")
52
+ DATA_DIR = os.path.join(BASE_DIR, "data", "tanishq")
53
+ IMAGE_DIR = os.path.join(DATA_DIR, "images")
54
+ BLIP_CAPTIONS_PATH = os.path.join(DATA_DIR, "blip_captions.json")
55
+
56
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
57
+
58
+ # %%
59
+ # ============================================================
60
+ # LAZY MODEL LOADING (Reduces cold start time)
61
+ # ============================================================
62
+
63
+ # Global model references (loaded on first use)
64
+ clip_model = None
65
+ cross_encoder = None
66
+
67
+ def get_clip_model():
68
+ """Lazy load CLIP model on first use"""
69
+ global clip_model
70
+ if clip_model is None:
71
+ print("πŸ”Ή Loading CLIP model...")
72
+ model, _ = clip.load("ViT-B/16", device=DEVICE)
73
+ model.eval()
74
+ clip_model = model
75
+ print("βœ… CLIP model loaded")
76
+ return clip_model
77
+
78
+ def get_cross_encoder():
79
+ """Lazy load Cross-Encoder on first use"""
80
+ global cross_encoder
81
+ if cross_encoder is None:
82
+ print("πŸ”Ή Loading Cross-Encoder for re-ranking...")
83
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
84
+ print("βœ… Cross-Encoder loaded")
85
+ return cross_encoder
86
+
87
+ # %%
88
+ print("πŸ”Ή Loading BLIP captions...")
89
+ with open(BLIP_CAPTIONS_PATH, "r") as f:
90
+ BLIP_CAPTIONS = json.load(f)
91
+
92
+ # %%
93
+ # ============================================================
94
+ # INITIALIZE GROQ LLM CLIENT
95
+ # ============================================================
96
+
97
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
98
+ if GROQ_API_KEY:
99
+ print("πŸ”Ή Initializing Groq LLM client...")
100
+ groq_client = OpenAI(
101
+ base_url="https://api.groq.com/openai/v1",
102
+ api_key=GROQ_API_KEY
103
+ )
104
+ else:
105
+ groq_client = None
106
+ print("⚠️ GROQ_API_KEY not set; LLM features disabled (fallbacks enabled)")
107
+
108
+ # NVIDIA OCR API configuration
109
+ NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY")
110
+ NVIDIA_OCR_URL = "https://ai.api.nvidia.com/v1/cv/nvidia/nemoretriever-ocr-v1"
111
+
112
+ # Fallback OCR configuration
113
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
114
+
115
+ # %%
116
+ # ============================================================
117
+ # LOAD CHROMA (PERSISTED DB)
118
+ # ============================================================
119
+
120
+ print("πŸ”Ή Connecting to Chroma DB...")
121
+ chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
122
+
123
+ image_collection = chroma_client.get_collection("jewelry_images")
124
+ metadata_collection = chroma_client.get_collection("jewelry_metadata")
125
+
126
+ print(
127
+ "βœ… Chroma loaded | Images:",
128
+ image_collection.count(),
129
+ "| Metadata:",
130
+ metadata_collection.count()
131
+ )
132
+
133
+ # %%
134
+ # ============================================================
135
+ # FASTAPI APP
136
+ # ============================================================
137
+
138
+ app = FastAPI(title="Jewellery Multimodal Search")
139
+
140
+ app.add_middleware(
141
+ CORSMiddleware,
142
+ allow_origins=[
143
+ "http://localhost:5173",
144
+ "https://tanishq-rag-capstone-1lj2x4y1v-akash-aimls-projects.vercel.app", # Vercel preview
145
+ "https://*.vercel.app", # All Vercel deployments
146
+ "https://*.ngrok-free.app", # Allow ngrok tunnels
147
+ "https://*.ngrok-free.dev", # Allow ngrok tunnels (new domain)
148
+ "https://*.ngrok.io", # Allow ngrok tunnels (legacy)
149
+ ],
150
+ allow_credentials=True,
151
+ allow_methods=["*"],
152
+ allow_headers=["*"],
153
+ )
154
+
155
+ # %%
156
+ # ============================================================
157
+ # MIDDLEWARE FOR HF SPACES OPTIMIZATION
158
+ # ============================================================
159
+
160
+ import asyncio
161
+ from starlette.requests import Request
162
+
163
+ @app.middleware("http")
164
+ async def add_optimizations(request: Request, call_next):
165
+ """Add upload size limits and request timeouts"""
166
+
167
+ # Limit upload size to 5MB
168
+ if request.method == "POST":
169
+ content_length = request.headers.get("content-length")
170
+ if content_length and int(content_length) > 5 * 1024 * 1024: # 5MB
171
+ raise HTTPException(status_code=413, detail="File too large (max 5MB)")
172
+
173
+ # Add request timeout (60s for local dev/slower machines)
174
+ try:
175
+ response = await asyncio.wait_for(call_next(request), timeout=60.0)
176
+ return response
177
+ except asyncio.TimeoutError:
178
+ raise HTTPException(status_code=504, detail="Request timeout (max 60s)")
179
+
180
+ # %%
181
+ # ============================================================
182
+ # REQUEST / RESPONSE SCHEMAS
183
+ # ============================================================
184
+
185
+ class TextSearchRequest(BaseModel):
186
+ query: str
187
+ filters: Dict[str, str] = None # Explicit UI filters (e.g. {"metal": "gold"})
188
+ top_k: int = 5
189
+ use_reranking: bool = True # Toggle cross-encoder (3x faster when False)
190
+ use_explanations: bool = True # Toggle LLM explanations (500ms+ faster when False)
191
+
192
+
193
+ class SimilarSearchRequest(BaseModel):
194
+ image_id: str
195
+ top_k: int = 5
196
+
197
+ # %%
198
+ # ============================================================
199
+ # CLIP QUERY ENCODING (TEXT ONLY)
200
+ # ============================================================
201
+
202
+ def encode_text_clip(text: str) -> np.ndarray:
203
+ """Encode text using CLIP with memory cleanup"""
204
+ model = get_clip_model()
205
+ tokens = clip.tokenize([text]).to(DEVICE)
206
+ with torch.no_grad():
207
+ emb = model.encode_text(tokens)
208
+ emb = emb / emb.norm(dim=-1, keepdim=True)
209
+ result = emb.cpu().numpy()[0]
210
+
211
+ # Memory cleanup
212
+ del tokens, emb
213
+ if DEVICE == "cuda":
214
+ torch.cuda.empty_cache()
215
+
216
+ return result
217
+
218
+ # %%
219
+ # ============================================================
220
+ # INTENT & ATTRIBUTE DETECTION WITH LLM (STRUCTURED)
221
+ # ============================================================
222
+
223
+ def detect_intent_and_attributes(query: str) -> Dict:
224
+ """
225
+ Extract search attributes and exclusions from query using LLM with fixed schema.
226
+
227
+ Returns:
228
+ {
229
+ "intent": "search",
230
+ "attributes": {category, metal, primary_stone}, # Items to INCLUDE
231
+ "exclusions": {category, metal, primary_stone} # Items to EXCLUDE
232
+ }
233
+ """
234
+
235
+ prompt = f"""Extract jewellery search attributes from this query.
236
+
237
+ Query: "{query}"
238
+
239
+ Return ONLY valid JSON with this exact schema:
240
+ {{
241
+ "intent": "search",
242
+ "attributes": {{
243
+ "category": "ring|necklace|earring|bracelet|null",
244
+ "metal": "gold|silver|platinum|null",
245
+ "primary_stone": "diamond|pearl|ruby|emerald|sapphire|null"
246
+ }},
247
+ "exclusions": {{
248
+ "category": "ring|necklace|earring|bracelet|null",
249
+ "metal": "gold|silver|platinum|null",
250
+ "primary_stone": "diamond|pearl|ruby|emerald|sapphire|null"
251
+ }}
252
+ }}
253
+
254
+ Rules:
255
+ - "attributes" = what to INCLUDE (positive filters)
256
+ - "exclusions" = what to EXCLUDE (negative filters)
257
+ - Use null for unspecified fields
258
+ - Detect negations: "no", "without", "not", "plain", "-free"
259
+
260
+ Examples:
261
+
262
+ Query: "gold ring with diamonds"
263
+ {{"intent": "search", "attributes": {{"category": "ring", "metal": "gold", "primary_stone": "diamond"}}, "exclusions": {{"category": null, "metal": null, "primary_stone": null}}}}
264
+
265
+ Query: "ring with no diamonds"
266
+ {{"intent": "search", "attributes": {{"category": "ring", "metal": null, "primary_stone": null}}, "exclusions": {{"category": null, "metal": null, "primary_stone": "diamond"}}}}
267
+
268
+ Query: "plain silver necklace"
269
+ {{"intent": "search", "attributes": {{"category": "necklace", "metal": "silver", "primary_stone": null}}, "exclusions": {{"category": null, "metal": null, "primary_stone": "any"}}}}
270
+
271
+ Query: "gold necklace without pearls"
272
+ {{"intent": "search", "attributes": {{"category": "necklace", "metal": "gold", "primary_stone": null}}, "exclusions": {{"category": null, "metal": null, "primary_stone": "pearl"}}}}
273
+
274
+ Return ONLY the JSON, no explanation."""
275
+
276
+ def simple_fallback() -> Dict:
277
+ q = query.lower()
278
+ attrs = {}
279
+
280
+ if "necklace" in q:
281
+ attrs["category"] = "necklace"
282
+ elif "ring" in q:
283
+ attrs["category"] = "ring"
284
+
285
+ if "gold" in q:
286
+ attrs["metal"] = "gold"
287
+ elif "silver" in q:
288
+ attrs["metal"] = "silver"
289
+
290
+ if "pearl" in q:
291
+ attrs["primary_stone"] = "pearl"
292
+ elif "diamond" in q:
293
+ attrs["primary_stone"] = "diamond"
294
+
295
+ return {
296
+ "intent": "search",
297
+ "attributes": attrs,
298
+ "exclusions": {}
299
+ }
300
+
301
+ if groq_client is None:
302
+ return simple_fallback()
303
+
304
+ try:
305
+ response = groq_client.chat.completions.create(
306
+ model="llama-3.3-70b-versatile",
307
+ messages=[{"role": "user", "content": prompt}],
308
+ temperature=0.1,
309
+ max_tokens=300
310
+ )
311
+
312
+ result_text = response.choices[0].message.content.strip()
313
+
314
+ # Extract JSON from response (handle markdown code blocks)
315
+ if "```json" in result_text:
316
+ result_text = result_text.split("```json")[1].split("```")[0].strip()
317
+ elif "```" in result_text:
318
+ result_text = result_text.split("```")[1].split("```")[0].strip()
319
+
320
+ result = json.loads(result_text)
321
+
322
+ # Clean null values
323
+ result["attributes"] = {k: v for k, v in result.get("attributes", {}).items() if v and v != "null"}
324
+ result["exclusions"] = {k: v for k, v in result.get("exclusions", {}).items() if v and v != "null"}
325
+
326
+ return result
327
+
328
+ except Exception as e:
329
+ print(f"⚠️ LLM extraction failed: {e}, falling back to simple extraction")
330
+ return simple_fallback()
331
+
332
+
333
+ # %%
334
+ # ============================================================
335
+ # VISUAL RETRIEVAL (NO LANGCHAIN)
336
+ # ============================================================
337
+
338
+ def retrieve_visual_candidates(query_text: str, k: int = 100, where_filter: Dict = None):
339
+ q_emb = encode_text_clip(query_text)
340
+
341
+ # Use Chroma's built-in filtering if provided
342
+ res = image_collection.query(
343
+ query_embeddings=[q_emb],
344
+ n_results=k,
345
+ where=where_filter
346
+ )
347
+
348
+ if not res["ids"] or not res["ids"][0]:
349
+ return []
350
+
351
+ return [
352
+ {
353
+ "image_id": img_id,
354
+ "visual_score": dist
355
+ }
356
+ for img_id, dist in zip(res["ids"][0], res["distances"][0])
357
+ ]
358
+
359
+ # %%
360
+ # ============================================================
361
+ # METADATA SCORING REFINEMENTS
362
+ # ============================================================
363
+
364
+ def adaptive_alpha(query_attrs: Dict) -> float:
365
+ return 0.1 + 0.1 * len(query_attrs)
366
+
367
+
368
+ def refined_metadata_adjustment(meta: Dict, query_attrs: Dict) -> float:
369
+ score = 0.0
370
+
371
+ for attr, q_val in query_attrs.items():
372
+ m_val = meta.get(attr)
373
+ conf = meta.get(f"confidence_{attr}", 0.0)
374
+
375
+ if m_val == q_val:
376
+ score += conf
377
+ elif conf > 0.6:
378
+ score -= 0.3 * conf
379
+
380
+ return score
381
+
382
+
383
+ def apply_metadata_boost(candidates: List[Dict], query_attrs: Dict, exclusions: Dict = None):
384
+ """
385
+ Rank candidates by combining visual similarity with metadata matching.
386
+ HARD FILTER out excluded items completely.
387
+
388
+ Args:
389
+ candidates: List of {image_id, visual_score}
390
+ query_attrs: Attributes to INCLUDE (boost matching items)
391
+ exclusions: Attributes to EXCLUDE (FILTER OUT completely)
392
+ """
393
+ if exclusions is None:
394
+ exclusions = {}
395
+
396
+ alpha = adaptive_alpha(query_attrs)
397
+ ranked = []
398
+
399
+ for c in candidates:
400
+ meta = metadata_collection.get(
401
+ ids=[c["image_id"]],
402
+ include=["metadatas"]
403
+ )["metadatas"][0]
404
+
405
+ # HARD FILTER: Skip items that match exclusions
406
+ should_exclude = False
407
+ for attr, excluded_value in exclusions.items():
408
+ meta_value = meta.get(attr)
409
+
410
+ # Handle "any" exclusion (e.g., "plain" means no stones at all)
411
+ if excluded_value == "any":
412
+ # Exclude if has ANY stone (not unknown/null)
413
+ if meta_value and meta_value not in ["unknown", "null", ""]:
414
+ should_exclude = True
415
+ print(f"🚫 Excluding {c['image_id']}: has {attr}={meta_value} (want none)")
416
+ break
417
+ # Handle specific exclusion
418
+ elif meta_value == excluded_value:
419
+ should_exclude = True
420
+ print(f"🚫 Excluding {c['image_id']}: has {attr}={meta_value} (excluded)")
421
+ break
422
+
423
+ # Skip this item if it matches any exclusion
424
+ if should_exclude:
425
+ continue
426
+
427
+ # Calculate positive boost from matching attributes
428
+ adjust = refined_metadata_adjustment(meta, query_attrs)
429
+
430
+ # Final score: visual + metadata boost (no exclusion penalty needed)
431
+ final_score = c["visual_score"] - alpha * adjust
432
+
433
+ ranked.append({
434
+ "image_id": c["image_id"],
435
+ "visual_score": c["visual_score"],
436
+ "metadata_boost": adjust,
437
+ "final_score": final_score
438
+ })
439
+
440
+ return sorted(ranked, key=lambda x: x["final_score"])
441
+
442
+
443
+ def rerank_with_cross_encoder(
444
+ query: str,
445
+ candidates: List[Dict],
446
+ top_k: int = 12
447
+ ) -> List[Dict]:
448
+ """
449
+ Re-rank candidates using cross-encoder for better semantic matching.
450
+
451
+ Two-stage pipeline:
452
+ 1. CLIP bi-encoder: Fast retrieval (already done)
453
+ 2. Cross-encoder: Accurate semantic re-ranking
454
+
455
+ Args:
456
+ query: User query text
457
+ candidates: List of {image_id, visual_score, metadata_boost, ...}
458
+ top_k: Number of results to return
459
+
460
+ Returns:
461
+ Re-ranked list of top K candidates
462
+ """
463
+ if not candidates:
464
+ return []
465
+
466
+ # Prepare query-document pairs for cross-encoder
467
+ pairs = []
468
+ for c in candidates:
469
+ # Get BLIP caption for this image
470
+ caption = BLIP_CAPTIONS.get(c["image_id"], "")
471
+
472
+ # Get metadata
473
+ meta = metadata_collection.get(
474
+ ids=[c["image_id"]],
475
+ include=["metadatas"]
476
+ )["metadatas"][0]
477
+
478
+ # Create rich text representation combining caption + metadata
479
+ doc_text = f"{caption}. Category: {meta.get('category', 'unknown')}, Metal: {meta.get('metal', 'unknown')}, Stone: {meta.get('primary_stone', 'unknown')}"
480
+
481
+ pairs.append([query, doc_text])
482
+
483
+ # Score all pairs with cross-encoder (batch processing)
484
+ print(f"πŸ”„ Cross-encoder scoring {len(pairs)} candidates...")
485
+ encoder = get_cross_encoder()
486
+ cross_scores = encoder.predict(pairs, batch_size=32)
487
+
488
+ # Combine scores: visual + metadata + cross-encoder
489
+ for i, c in enumerate(candidates):
490
+ c["cross_encoder_score"] = float(cross_scores[i])
491
+
492
+ # Final score combines all signals
493
+ # - Visual similarity (CLIP): 30% weight
494
+ # - Metadata match: 20% weight
495
+ # - Semantic similarity (cross-encoder): 50% weight (highest)
496
+ c["final_score_reranked"] = (
497
+ -c["visual_score"] * 0.3 + # Negate because lower distance = better
498
+ c.get("metadata_boost", 0) * 0.2 +
499
+ c["cross_encoder_score"] * 0.5
500
+ )
501
+
502
+ # Sort by final score (higher is better)
503
+ ranked = sorted(candidates, key=lambda x: x["final_score_reranked"], reverse=True)
504
+
505
+ print(f"βœ… Re-ranked {len(ranked)} candidates, returning top {top_k}")
506
+ return ranked[:top_k]
507
+
508
+
509
+ # %%
510
+ # ============================================================
511
+ # IMAGE UPLOAD & OCR HELPER FUNCTIONS
512
+ # ============================================================
513
+
514
+ def encode_uploaded_image(image_bytes: bytes) -> np.ndarray:
515
+ """Encode uploaded image using CLIP model"""
516
+ try:
517
+ # Open image from bytes
518
+ image = Image.open(io.BytesIO(image_bytes))
519
+
520
+ # Convert to RGB if necessary
521
+ if image.mode != 'RGB':
522
+ image = image.convert('RGB')
523
+
524
+ # Resize if too large (max 512x512 for efficiency)
525
+ max_size = 512
526
+ if max(image.size) > max_size:
527
+ image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
528
+
529
+ # Preprocess for CLIP
530
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
531
+
532
+ preprocess = Compose([
533
+ Resize(224, interpolation=Image.BICUBIC),
534
+ CenterCrop(224),
535
+ ToTensor(),
536
+ Normalize((0.48145466, 0.4578275, 0.40821073),
537
+ (0.26862954, 0.26130258, 0.27577711))
538
+ ])
539
+
540
+ image_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
541
+
542
+ # Encode with CLIP
543
+ model = get_clip_model()
544
+ with torch.no_grad():
545
+ image_features = model.encode_image(image_tensor)
546
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
547
+ result = image_features.cpu().numpy()[0]
548
+
549
+ # Memory cleanup
550
+ del image_tensor, image_features
551
+ if DEVICE == "cuda":
552
+ torch.cuda.empty_cache()
553
+
554
+ return result
555
+
556
+ except Exception as e:
557
+ raise HTTPException(status_code=400, detail=f"Failed to process image: {str(e)}")
558
+
559
+
560
+ def extract_text_from_image(image_bytes: bytes) -> str:
561
+ """Extract text from image using NVIDIA NeMo Retriever OCR API with GPT-4.1-Nano fallback"""
562
+
563
+ # Try NVIDIA OCR first if key is configured
564
+ extracted_text = ""
565
+ nvidia_failed = False
566
+
567
+ if NVIDIA_API_KEY:
568
+ try:
569
+ # Encode image to base64
570
+ image_b64 = base64.b64encode(image_bytes).decode()
571
+
572
+ # Prepare request
573
+ headers = {
574
+ "Authorization": f"Bearer {NVIDIA_API_KEY}",
575
+ "Accept": "application/json",
576
+ "Content-Type": "application/json"
577
+ }
578
+
579
+ payload = {
580
+ "input": [
581
+ {
582
+ "type": "image_url",
583
+ "url": f"data:image/png;base64,{image_b64}"
584
+ }
585
+ ]
586
+ }
587
+
588
+ # Call NVIDIA OCR API
589
+ print(f"πŸ“ž Calling NVIDIA OCR API...")
590
+ response = requests.post(
591
+ NVIDIA_OCR_URL,
592
+ headers=headers,
593
+ json=payload,
594
+ timeout=15 # Shorter timeout to fail fast
595
+ )
596
+
597
+ if response.status_code == 200:
598
+ result = response.json()
599
+
600
+ # Format 1: Text detections array
601
+ if "data" in result and isinstance(result["data"], list) and len(result["data"]) > 0:
602
+ for data_item in result["data"]:
603
+ if isinstance(data_item, dict) and "text_detections" in data_item:
604
+ for detection in data_item["text_detections"]:
605
+ if "text_prediction" in detection and "text" in detection["text_prediction"]:
606
+ extracted_text += detection["text_prediction"]["text"] + " "
607
+ elif isinstance(data_item, dict) and "content" in data_item:
608
+ extracted_text += data_item["content"] + " "
609
+
610
+ # Format 2: Direct text field
611
+ elif "text" in result:
612
+ extracted_text = result["text"]
613
+
614
+ # Format 3: Choices/Results
615
+ elif "choices" in result and len(result["choices"]) > 0:
616
+ if "text" in result["choices"][0]:
617
+ extracted_text = result["choices"][0]["text"]
618
+ elif "message" in result["choices"][0]:
619
+ extracted_text = result["choices"][0]["message"].get("content", "")
620
+
621
+ extracted_text = extracted_text.strip()
622
+ if extracted_text:
623
+ print(f"βœ… Extracted text (NVIDIA): '{extracted_text}'")
624
+ return extracted_text
625
+
626
+ print(f"⚠️ NVIDIA OCR failed with status {response.status_code}. Trying fallback...")
627
+ nvidia_failed = True
628
+
629
+ except Exception as e:
630
+ print(f"⚠️ NVIDIA OCR exception: {e}. Trying fallback...")
631
+ nvidia_failed = True
632
+ else:
633
+ print("ℹ️ NVIDIA_API_KEY not set. Using fallback directly.")
634
+ nvidia_failed = True
635
+
636
+ # FALLBACK: Custom GPT-4.1-Nano OCR (OpenAI Compatible)
637
+ try:
638
+ if not OPENAI_API_KEY:
639
+ raise HTTPException(status_code=500, detail="OCR unavailable: Primary failed and OPENAI_API_KEY not set for fallback.")
640
+
641
+ print("πŸ”„ Using Global GPT-4.1-Nano Fallback...")
642
+
643
+ # Initialize OpenAI client with custom base URL
644
+ from openai import OpenAI
645
+
646
+ client = OpenAI(
647
+ base_url="https://apidev.navigatelabsai.com/v1",
648
+ api_key=OPENAI_API_KEY
649
+ )
650
+
651
+ image_b64 = base64.b64encode(image_bytes).decode()
652
+
653
+ response = client.chat.completions.create(
654
+ model="gpt-4.1-nano",
655
+ messages=[
656
+ {
657
+ "role": "user",
658
+ "content": [
659
+ {"type": "text", "text": "Transcribe the handwritten text in this image exactly as it appears. Output ONLY the text, nothing else."},
660
+ {
661
+ "type": "image_url",
662
+ "image_url": {
663
+ "url": f"data:image/png;base64,{image_b64}"
664
+ }
665
+ }
666
+ ]
667
+ }
668
+ ],
669
+ max_tokens=300
670
+ )
671
+
672
+ extracted_text = response.choices[0].message.content.strip()
673
+
674
+ if not extracted_text:
675
+ raise HTTPException(status_code=400, detail="No readable text found in image (Fallback).")
676
+
677
+ print(f"βœ… Extracted text (Fallback GPT): '{extracted_text}'")
678
+ return extracted_text
679
+
680
+ except HTTPException:
681
+ raise
682
+ except Exception as e:
683
+ print(f"❌ Fallback OCR failed: {e}")
684
+ raise HTTPException(status_code=500, detail=f"OCR failed permanently: {str(e)}")
685
+
686
+
687
+ # %%
688
+ # ============================================================
689
+ # LLM-POWERED EXPLANATION (GROQ LLAMA 3.1) - BATCH PROCESSING
690
+ # ============================================================
691
+
692
def batch_generate_explanations(results: List[Dict], query_attrs: Dict, user_query: str) -> List[str]:
    """Generate diverse, LLM-powered explanations for all search results in ONE API call.

    Args:
        results: Ranked candidates; each dict has at least "image_id" and "visual_score".
        query_attrs: Attribute dict detected from the user query (e.g. metal, stone).
        user_query: Raw query text, passed to the LLM for context.

    Returns:
        One explanation string per result, in the same order. Any item the LLM
        did not cover gets a deterministic template-based fallback sentence.
    """
    if not results:
        return []

    # Fetch metadata for ALL result IDs in a single ChromaDB call.
    # (Previously this issued one .get() per item — N round-trips.)
    result_ids = [r["image_id"] for r in results]
    meta_res = metadata_collection.get(ids=result_ids, include=["metadatas"])
    meta_map = {rid: m for rid, m in zip(meta_res["ids"], meta_res["metadatas"])}

    # Build a compact per-item context line to keep the prompt within token limits
    items_context = []
    for idx, r in enumerate(results, 1):
        meta = meta_map.get(r["image_id"], {})
        matched_attrs = [v for k, v in query_attrs.items() if meta.get(k) == v]
        item_info = (
            f"{idx}. {meta.get('category', 'item')} | {meta.get('metal', '?')} | "
            f"{meta.get('primary_stone', '?')} | score:{r['visual_score']:.2f} | "
            f"matched:{','.join(matched_attrs) if matched_attrs else 'none'}"
        )
        items_context.append(item_info)

    # Compact prompt to fit more items
    prompt = f"""Query: "{user_query}"

Write 1 brief sentence for EACH item:

{chr(10).join(items_context)}

Format:
1. [sentence]
2. [sentence]
etc."""

    if groq_client is None:
        explanations = []
    else:
        try:
            # Single API call for ALL items
            response = groq_client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[
                    {"role": "system", "content": "Write brief jewellery recommendations."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.7,
                max_tokens=min(800, len(results) * 60),  # budget scales with item count (12+ items)
                top_p=0.9
            )

            full_response = response.choices[0].message.content.strip()
            explanations = []

            import re
            # Parse a numbered list: accepts "1. text", "2) text", "3: text", "4- text"
            pattern = r'^\s*(\d+)[\.:)\-]\s*(.+?)(?=^\s*\d+[\.:)\-]|\Z)'
            matches = re.findall(pattern, full_response, re.MULTILINE | re.DOTALL)

            if matches and len(matches) >= len(results):
                for num, text in matches[:len(results)]:
                    clean_text = ' '.join(text.strip().split())
                    # Discard trivially short fragments (parser noise)
                    if clean_text and len(clean_text) > 10:
                        explanations.append(clean_text)

            if len(explanations) >= len(results):
                return explanations[:len(results)]

            # If incomplete, pad with fallback
            print(f"⚠️ LLM returned {len(explanations)}/{len(results)} explanations, padding with fallback")

        except Exception as e:
            print(f"⚠️ LLM explanation failed: {e}, using fallback")
            explanations = []

    # Deterministic template fallback for any items still missing an explanation
    while len(explanations) < len(results):
        idx = len(explanations)
        r = results[idx]
        meta = meta_map.get(r["image_id"], {})
        matched_attrs = [v for k, v in query_attrs.items() if meta.get(k) == v]

        category = meta.get('category', 'item')
        metal = meta.get('metal', 'unknown')
        stone = meta.get('primary_stone', 'unknown')

        # NOTE(review): 1.3 appears to be a distance threshold for "high similarity"
        # (lower distance = better match) — confirm against the embedding metric.
        if matched_attrs and r['visual_score'] < 1.3:
            explanations.append(
                f"Excellent {category} featuring {' and '.join(matched_attrs)}. High visual similarity (score: {r['visual_score']:.2f})."
            )
        elif matched_attrs:
            explanations.append(
                f"Beautiful {metal} {category} with {stone}. Features {' and '.join(matched_attrs)}."
            )
        elif r['visual_score'] < 1.3:
            explanations.append(
                f"Highly similar {category} with excellent visual match. {metal.capitalize()} with {stone}."
            )
        else:
            explanations.append(
                f"Recommended {metal} {category} with {stone}. Good visual similarity."
            )

    return explanations
796
+ # ============================================================
797
+ # API ENDPOINTS
798
+ # ============================================================
799
+
800
@app.get("/health")
def health_check():
    """Health check endpoint for HF Spaces monitoring."""
    model_status = {
        "clip": clip_model is not None,
        "cross_encoder": cross_encoder is not None,
        "blip_captions": len(BLIP_CAPTIONS) > 0,
    }
    db_status = {
        "images": image_collection.count(),
        "metadata": metadata_collection.count(),
    }
    return {
        "status": "healthy",
        "models_loaded": model_status,
        "database": db_status,
    }
815
+
816
@app.post("/search/text")
def search_text(req: TextSearchRequest):
    """Text / filter search endpoint.

    Dual-stage filtering strategy:
      1. Resolve the set of valid IDs from the metadata collection (source of truth).
      2. Run a broad CLIP vector search and intersect with those IDs in Python,
         because the image collection itself carries no filterable metadata.

    Supports three modes: filter-only (empty query), text-only, and text+filter.
    """
    # Detect intent/attributes from the free-text query; empty query => filter-only mode
    if req.query.strip():
        intent = detect_intent_and_attributes(req.query)
        attrs = intent["attributes"]
    else:
        intent = {"intent": "filter", "attributes": {}, "exclusions": {}}
        attrs = {}

    # Construct WHERE clause for the metadata collection from explicit filters
    where_clauses = []
    if req.filters:
        for key, value in req.filters.items():
            where_clauses.append({key: value.lower()})  # filter values are stored lowercased

    final_where = None
    if len(where_clauses) > 1:
        final_where = {"$and": where_clauses}
    elif len(where_clauses) == 1:
        final_where = where_clauses[0]

    valid_ids = None
    if final_where:
        # Fetch ALL valid IDs matching the filter
        print(f"πŸ” Filtering metadata with: {final_where}")
        meta_res = metadata_collection.get(where=final_where, include=["metadatas"])
        if meta_res["ids"]:
            valid_ids = set(meta_res["ids"])
            print(f"βœ… Found {len(valid_ids)} valid items matching filters.")
        else:
            print("⚠️ No items match the filters.")
            return {"query": req.query, "intent": attrs, "results": []}

    # === EXECUTE SEARCH ===

    # Case A: filter only (no text query) — return matching items directly, no ranking
    if not req.query.strip() and valid_ids:
        candidates = [{"image_id": vid, "visual_score": 0.0} for vid in list(valid_ids)[:req.top_k]]
        ranked = candidates
        explanations = ["Filtered result"] * len(ranked)

    # Case B: text query (with or without filter)
    else:
        search_query = req.query if req.query.strip() else "jewellery"

        # Broad vector search, then intersect with valid_ids in Python.
        # Fetch more candidates when filtering so the intersection is unlikely empty.
        fetch_k = 200 if valid_ids else 40

        # Note: we do NOT pass 'where' to retrieve_visual_candidates because the
        # image collection lacks metadata; the intersection below does the filtering.
        candidates = retrieve_visual_candidates(search_query, k=fetch_k)

        if valid_ids is not None:
            filtered_candidates = [c for c in candidates if c["image_id"] in valid_ids]
        else:
            filtered_candidates = list(candidates)

        filtered = filtered_candidates  # apply_metadata_boost(filtered_candidates, attrs, {})

        # Cross-encoder re-ranking (only meaningful with a real text query)
        if req.use_reranking and filtered and req.query.strip():
            ranked = rerank_with_cross_encoder(req.query, filtered, req.top_k)
        else:
            ranked = filtered[:req.top_k]

        # LLM explanations (single batched call); skipped for empty queries
        if req.use_explanations and req.query.strip():
            explanations = batch_generate_explanations(ranked, attrs, search_query)
        else:
            explanations = ["Match found"] * len(ranked)

    # === FORMAT RESULTS ===
    # Fetch metadata for the final ranked IDs in one call
    if ranked:
        ranked_ids = [r["image_id"] for r in ranked]
        metas = metadata_collection.get(ids=ranked_ids, include=["metadatas"])["metadatas"]
        meta_map = {rid: m for rid, m in zip(ranked_ids, metas)}
    else:
        meta_map = {}

    results = []
    for r, explanation in zip(ranked, explanations):
        results.append({
            "image_id": r["image_id"],
            "explanation": explanation,
            "metadata": meta_map.get(r["image_id"], {}),
            "scores": {
                "visual": r["visual_score"],
                "final": r.get("visual_score", 0)  # simplified: no separate boost score here
            }
        })

    return {
        "query": req.query,
        "intent": attrs,
        "results": results
    }
936
+
937
+ # %%
938
@app.post("/search/similar")
def search_similar(req: SimilarSearchRequest):
    """Find catalogue items visually similar to an existing image.

    Returns HTTP 404 when `image_id` is not in the database (previously an
    unknown id raised IndexError and surfaced as a 500).
    """
    base_res = image_collection.get(
        ids=[req.image_id],
        include=["embeddings"]
    )
    if not base_res["ids"]:
        raise HTTPException(status_code=404, detail=f"Image '{req.image_id}' not found")
    base = base_res["embeddings"][0]

    # Fetch one extra neighbour: the base image is typically its own best match
    res = image_collection.query(
        query_embeddings=[base],
        n_results=req.top_k + 1
    )

    base_meta = metadata_collection.get(
        ids=[req.image_id],
        include=["metadatas"]
    )["metadatas"][0]

    # Attributes of the base item drive the metadata boost; "unknown" values are skipped
    attrs = {
        k: base_meta[k]
        for k in ["category", "metal", "primary_stone"]
        if base_meta.get(k) != "unknown"
    }

    candidates = [
        {
            "image_id": img_id,
            "visual_score": dist
        }
        for img_id, dist in zip(res["ids"][0], res["distances"][0])
        if img_id != req.image_id  # drop the query image itself
    ]

    ranked = apply_metadata_boost(candidates, attrs, {})[:req.top_k]

    # Generate all explanations in one batch LLM call;
    # for similar-search, the base image id serves as the query context
    query_context = f"items similar to {req.image_id}"
    explanations = batch_generate_explanations(ranked, attrs, query_context)

    results = []
    for r, explanation in zip(ranked, explanations):
        results.append({
            "image_id": r["image_id"],
            "explanation": explanation,
            "scores": {
                "visual": r["visual_score"],
                "metadata": r["metadata_boost"],
                "final": r["final_score"]
            }
        })

    return {
        "base_image": req.image_id,
        "results": results
    }
993
+
994
+ # %%
995
+ # ============================================================
996
+ # IMAGE UPLOAD SEARCH ENDPOINT
997
+ # ============================================================
998
+
999
@app.post("/search/upload-image")
async def search_by_uploaded_image(
    file: UploadFile = File(...),
    top_k: int = 12
):
    """
    Search for similar jewellery items by uploading an image.
    The image is encoded using CLIP and queried against the database.
    """
    # Reject anything that is not declared as an image upload
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        raw_bytes = await file.read()

        # CLIP-encode the uploaded picture
        query_embedding = encode_uploaded_image(raw_bytes)

        # Vector search: over-fetch so the metadata boost has room to re-order
        search_res = image_collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=min(100, top_k * 10),
            include=["distances"]
        )

        candidates = [
            {"image_id": img_id, "visual_score": dist}
            for img_id, dist in zip(search_res["ids"][0], search_res["distances"][0])
        ]

        # Infer query attributes from the top hit's metadata (skip "unknown" values)
        attrs = {}
        if candidates:
            top_meta = metadata_collection.get(
                ids=[candidates[0]["image_id"]],
                include=["metadatas"]
            )["metadatas"][0]
            attrs = {
                key: top_meta[key]
                for key in ["category", "metal", "primary_stone"]
                if top_meta.get(key) != "unknown"
            }

        # Metadata boost with no exclusions, then truncate to top_k
        ranked = apply_metadata_boost(candidates, attrs, {})[:top_k]

        # One batched LLM call for every explanation
        query_context = f"items visually similar to uploaded image"
        explanations = batch_generate_explanations(ranked, attrs, query_context)

        results = [
            {
                "image_id": item["image_id"],
                "explanation": text,
                "scores": {
                    "visual": item["visual_score"],
                    "metadata": item["metadata_boost"],
                    "final": item["final_score"]
                }
            }
            for item, text in zip(ranked, explanations)
        ]

        return {
            "query_type": "uploaded_image",
            "filename": file.filename,
            "results": results
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Image search failed: {str(e)}")
1078
+
1079
+
1080
+ # %%
1081
+ # ============================================================
1082
+ # OCR QUERY SEARCH ENDPOINT
1083
+ # ============================================================
1084
+
1085
@app.post("/search/ocr-query")
async def search_by_ocr_query(
    file: UploadFile = File(...),
    top_k: int = 12
):
    """
    Extract text from uploaded image using NVIDIA NeMo OCR,
    then perform text-based search with the extracted query.
    """
    # Reject non-image uploads up front
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        raw_bytes = await file.read()

        # OCR stage: pull the query text out of the picture
        extracted_text = extract_text_from_image(raw_bytes)
        print(f"πŸ“ Extracted text from image: '{extracted_text}'")

        # Intent stage: treat the extracted text exactly like a typed query
        intent = detect_intent_and_attributes(extracted_text)
        attrs = intent["attributes"]
        exclusions = intent.get("exclusions", {})

        # Stage 1: CLIP retrieval (k=40 keeps resource usage modest on HF Spaces)
        visual_hits = retrieve_visual_candidates(extracted_text, k=40)

        # Stage 2: metadata boost + exclusion filtering
        boosted = apply_metadata_boost(visual_hits, attrs, exclusions)

        # Stage 3: cross-encoder re-ranking down to top_k
        ranked = rerank_with_cross_encoder(extracted_text, boosted, top_k)

        # One batched LLM call for every explanation
        explanations = batch_generate_explanations(ranked, attrs, extracted_text)

        results = [
            {
                "image_id": item["image_id"],
                "explanation": text,
                "scores": {
                    "visual": item["visual_score"],
                    "metadata": item["metadata_boost"],
                    "final": item["final_score"]
                }
            }
            for item, text in zip(ranked, explanations)
        ]

        return {
            "query_type": "ocr_extracted",
            "extracted_text": extracted_text,
            "intent": attrs,
            "results": results
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OCR search failed: {str(e)}")
1147
+
1148
+ # %%
1149
@app.get("/image/{image_id}")
def get_image(image_id: str):
    """Serve a catalogue image file by its id (filename inside IMAGE_DIR).

    `image_id` is untrusted URL input: the resolved path is verified to stay
    inside IMAGE_DIR so traversal sequences (e.g. encoded "..") cannot reach
    arbitrary files. Both traversal attempts and missing files return 404.
    """
    base_dir = os.path.abspath(IMAGE_DIR)
    path = os.path.abspath(os.path.join(base_dir, image_id))
    # Security: reject anything that escapes the image directory
    if not path.startswith(base_dir + os.sep):
        raise HTTPException(status_code=404, detail="Image not found")
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(path)
1155
+
1156
+ # %%
1157
+ # ============================================================
1158
+ # RUN SERVER
1159
+ # ============================================================
1160
+
1161
if __name__ == "__main__":
    import uvicorn

    # Startup banner: echo the resolved paths so deployment issues are visible in logs
    banner_lines = [
        "πŸš€ Starting Jewellery Search API server...",
        f"πŸ“ Data directory: {DATA_DIR}",
        f"πŸ“ Image directory: {IMAGE_DIR}",
        f"πŸ“ ChromaDB path: {CHROMA_PATH}",
        "🌐 Server will run on: http://localhost:8000",
        "πŸ“– API docs available at: http://localhost:8000/docs",
    ]
    for line in banner_lines:
        print(line)

    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")