"""FastAPI service that identifies a game-skin image.

Pipeline: OCR the bottom strip of the uploaded image, score every reference
image by shared filename words, and fall back to / tie-break with CLIP visual
similarity (sentence-transformers 'clip-ViT-B-32').
"""
import os
import glob
import zipfile
from io import BytesIO

import numpy as np
import torch
import easyocr
from PIL import Image
from fastapi import FastAPI, File, UploadFile
from sentence_transformers import SentenceTransformer, util

app = FastAPI()

# --- CONFIG ---
DATABASE_ZIP = "database.zip"
DATABASE_PATH = "database"
CACHE_FILE = "db_cache.pt"

# --- GLOBALS (populated once by load_resources() at import time) ---
model = None          # CLIP image encoder (SentenceTransformer)
reader = None         # EasyOCR reader
db_embeddings = None  # torch.Tensor (N, dim) of reference embeddings, or None when DB is empty
db_names = []         # filename stems, parallel to db_embeddings rows


def _unzip_database():
    """Extract DATABASE_ZIP into the working directory if it is present.

    Extraction is forced on every startup so a freshly uploaded zip always
    wins over a previously extracted folder.
    """
    if not os.path.exists(DATABASE_ZIP):
        return
    print(f"📦 Found {DATABASE_ZIP}, checking contents...")
    try:
        with zipfile.ZipFile(DATABASE_ZIP, 'r') as zip_ref:
            zip_ref.extractall(".")
        print("✅ Unzipped successfully!")
    except Exception as e:
        # Best-effort: a corrupt zip must not prevent the app from booting
        # with whatever images are already on disk.
        print(f"❌ Error unzipping: {e}")


def _build_index():
    """Embed every image under DATABASE_PATH and persist the result to CACHE_FILE.

    Returns:
        (embeddings, names): a (N, dim) tensor and the parallel list of
        filename stems, or (None, []) when no image could be embedded.
    """
    print("Building fresh index from images...")
    temp_emb = []
    temp_names = []
    os.makedirs(DATABASE_PATH, exist_ok=True)
    files = glob.glob(os.path.join(DATABASE_PATH, "*"))
    print(f"Found {len(files)} images in folder.")
    for f in files:
        try:
            img = Image.open(f).convert("RGB")
            emb = model.encode(img, convert_to_tensor=True)
            temp_emb.append(emb)
            # The filename without its extension becomes the item ID.
            temp_names.append(os.path.basename(f).rsplit('.', 1)[0])
        except Exception as e:
            # Skip unreadable / non-image files rather than aborting the index.
            print(f"Skip {f}: {e}")
    if not temp_emb:
        return None, []
    embeddings = torch.stack(temp_emb)
    torch.save({'embeddings': embeddings, 'names': temp_names}, CACHE_FILE)
    return embeddings, temp_names


def load_resources():
    """Initialise the CLIP model, the OCR reader and the reference index.

    Populates the module globals `model`, `reader`, `db_embeddings`,
    `db_names`. Called once at import time.
    """
    global model, reader, db_embeddings, db_names

    _unzip_database()

    print("Loading AI Models...")
    model = SentenceTransformer('clip-ViT-B-32')

    print("Loading OCR...")
    # Fall back to CPU automatically when no GPU is available in the Space.
    reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

    print("Indexing Database...")
    # NOTE: the cache is only trusted when no zip is present — a zip upload
    # always forces a rebuild so newly uploaded images are picked up.
    if os.path.exists(CACHE_FILE) and not os.path.exists(DATABASE_ZIP):
        print("Loading from cache...")
        cache_data = torch.load(CACHE_FILE)
        db_embeddings = cache_data['embeddings']
        db_names = cache_data['names']
    else:
        db_embeddings, db_names = _build_index()

    print(f"Ready! Loaded {len(db_names)} reference items.")


# Initialize on startup (import time).
load_resources()


# '_', '-' and '.' all act as word separators in filenames and OCR text.
_SEPARATORS = str.maketrans({'_': ' ', '-': ' ', '.': ' '})


def calculate_text_match(db_filename, ocr_text):
    """Count how many normalized words `db_filename` and `ocr_text` share.

    Both strings are lowercased and '_', '-', '.' are treated as spaces
    before splitting into word sets.
    """
    db_words = set(db_filename.lower().translate(_SEPARATORS).split())
    ocr_words = set(ocr_text.lower().translate(_SEPARATORS).split())
    return len(db_words & ocr_words)


def _best_visual_match(query_img, candidate_idxs=None):
    """Return the database index visually closest to `query_img`.

    `candidate_idxs` restricts the search to a subset of the database;
    None searches the whole index.
    """
    emb_query = model.encode(query_img, convert_to_tensor=True)
    if candidate_idxs is None:
        hits = util.semantic_search(emb_query, db_embeddings, top_k=1)[0]
        return hits[0]['corpus_id']
    subset_emb = db_embeddings[candidate_idxs]
    hits = util.semantic_search(emb_query, subset_emb, top_k=1)[0]
    # semantic_search returns an index into the subset — map back to the DB.
    return candidate_idxs[hits[0]['corpus_id']]


@app.get("/")
def health_check():
    """Liveness probe; also reports how many reference items are indexed."""
    return {"status": "running", "database_size": len(db_names) if db_names else 0}


@app.post("/ocr")
async def identify_skin(image: UploadFile = File(...)):
    """Identify the uploaded skin image.

    Strategy: OCR the bottom 30% of the image (where the label usually sits),
    score every DB entry by shared words, then fall back to — or tie-break
    with — CLIP visual similarity.

    Returns a dict with the cleaned-up name, the raw OCR text, and which
    matching method decided the result.
    """
    # 1. Read the upload into a PIL image.
    contents = await image.read()
    query_img = Image.open(BytesIO(contents)).convert("RGB")

    # 2. OCR only the bottom 30%, where the label text normally appears.
    w, h = query_img.size
    bottom_crop = query_img.crop((0, int(h * 0.70), w, h))
    ocr_result = reader.readtext(np.array(bottom_crop), detail=0)
    detected_text = " ".join(ocr_result).lower()

    # 3. Matching logic.
    if not db_names:
        return {"name": "Database Empty", "ocr_raw": detected_text, "method": "Error"}

    all_scores = [calculate_text_match(db_name, detected_text) for db_name in db_names]
    max_score = max(all_scores)
    candidates = [idx for idx, score in enumerate(all_scores) if score == max_score]

    # Case A: strong text match (>= 2 shared words).
    if max_score >= 2:
        if len(candidates) == 1:
            method = "Text Lock"
            final_idx = candidates[0]
        else:
            # Several equally good text matches: break the tie visually.
            method = "Hybrid"
            final_idx = _best_visual_match(query_img, candidates)
    # Case B: weak text match (exactly 1 shared word) — visual search
    # restricted to the text candidates.
    elif max_score == 1:
        method = "Visual (Filtered)"
        final_idx = _best_visual_match(query_img, candidates)
    # Case C: no text match — visual search over the full database.
    else:
        method = "Visual Only"
        final_idx = _best_visual_match(query_img)

    # Present the filename stem as a human-readable name.
    result_name = db_names[final_idx]
    final_clean = result_name.lstrip(" -_").replace("_", " ").replace("-", " ")
    final_clean = " ".join(final_clean.split())

    return {
        "name": final_clean,
        "ocr_raw": detected_text,
        "method": method
    }