# app.py — OCR + CLIP visual-matching service (Hugging Face Space "ocr", rev 5ed3d23)
import os
import glob
import torch
import easyocr
import zipfile # <--- Added for unzipping
import numpy as np
from PIL import Image
from io import BytesIO
from fastapi import FastAPI, File, UploadFile
from sentence_transformers import SentenceTransformer, util
app = FastAPI()
# --- CONFIG ---
DATABASE_ZIP = "database.zip"   # uploaded archive of reference images (extracted at startup)
DATABASE_PATH = "database"      # folder holding the extracted reference images
CACHE_FILE = "db_cache.pt"      # torch-serialized {'embeddings', 'names'} index cache
# --- GLOBALS ---
# Populated once by load_resources() at import time.
model = None          # SentenceTransformer CLIP model (image/text embeddings)
reader = None         # easyocr reader for English text
db_embeddings = None  # stacked tensor of reference-image embeddings (or None if empty)
db_names = []         # filenames (sans extension) aligned with db_embeddings rows
def load_resources():
    """Initialize the CLIP model, the OCR reader, and the reference-image index.

    Side effects: assigns the module globals ``model``, ``reader``,
    ``db_embeddings`` and ``db_names``. Intended to be called once at startup.
    """
    global model, reader, db_embeddings, db_names

    # 1. AUTO-UNZIP LOGIC
    # If a database.zip is present, always extract it so the image folder
    # reflects the most recent upload.
    if os.path.exists(DATABASE_ZIP):
        print(f"📦 Found {DATABASE_ZIP}, checking contents...")
        try:
            with zipfile.ZipFile(DATABASE_ZIP, 'r') as zip_ref:
                zip_ref.extractall(".")
            print("✅ Unzipped successfully!")
        except Exception as e:
            # Best-effort: a corrupt archive must not prevent startup.
            print(f"❌ Error unzipping: {e}")

    print("Loading AI Models...")
    model = SentenceTransformer('clip-ViT-B-32')

    print("Loading OCR...")
    # Use the GPU only when one is actually available in the Space.
    reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

    # --- LOAD DATABASE ---
    print("Indexing Database...")
    # Only trust the cache when no zip is present; a zip upload forces a
    # fresh re-index so newly added images are picked up.
    if os.path.exists(CACHE_FILE) and not os.path.exists(DATABASE_ZIP):
        print("Loading from cache...")
        # map_location guards against a cache produced on a GPU machine
        # being loaded on a CPU-only Space.
        cache_data = torch.load(CACHE_FILE, map_location="cpu")
        db_embeddings = cache_data['embeddings']
        db_names = cache_data['names']
    else:
        print("Building fresh index from images...")
        temp_emb = []
        temp_names = []
        os.makedirs(DATABASE_PATH, exist_ok=True)
        # Sort for a deterministic index order (stable ids across restarts).
        files = sorted(glob.glob(os.path.join(DATABASE_PATH, "*")))
        print(f"Found {len(files)} images in folder.")
        for f in files:
            try:
                img = Image.open(f).convert("RGB")
                emb = model.encode(img, convert_to_tensor=True)
                temp_emb.append(emb)
                # Strip the extension so the filename doubles as the item id.
                name = os.path.basename(f).rsplit('.', 1)[0]
                temp_names.append(name)
            except Exception as e:
                # Skip unreadable/non-image files (e.g. __MACOSX artifacts).
                print(f"Skip {f}: {e}")
        if temp_emb:
            db_embeddings = torch.stack(temp_emb)
            db_names = temp_names
            torch.save({'embeddings': db_embeddings, 'names': db_names}, CACHE_FILE)

    print(f"Ready! Loaded {len(db_names)} reference items.")
# Initialize on startup: runs at import time so models and the index are
# ready before the first request is served.
load_resources()
def calculate_text_match(db_filename, ocr_text):
    """Count the normalized words shared by a DB filename and the OCR text.

    Underscores, hyphens and dots are treated as word separators; comparison
    is case-insensitive. Returns the size of the word-set intersection.
    """
    def tokenize(text):
        cleaned = text.lower()
        for separator in ("_", "-", "."):
            cleaned = cleaned.replace(separator, " ")
        return set(cleaned.split())

    return len(tokenize(db_filename) & tokenize(ocr_text))
@app.get("/")
def health_check():
return {"status": "running", "database_size": len(db_names) if db_names else 0}
@app.post("/ocr")
async def identify_skin(image: UploadFile = File(...)):
# 1. Read Image
contents = await image.read()
query_img = Image.open(BytesIO(contents)).convert("RGB")
# 2. OCR (Bottom 30% Logic)
w, h = query_img.size
# Crop bottom 30% for text detection
bottom_crop = query_img.crop((0, int(h*0.70), w, h))
bottom_np = np.array(bottom_crop)
ocr_result = reader.readtext(bottom_np, detail=0)
detected_text = " ".join(ocr_result).lower()
# 3. MATCHING LOGIC
if not db_names:
return {"name": "Database Empty", "ocr_raw": detected_text, "method": "Error"}
all_scores = []
for db_name in db_names:
score = calculate_text_match(db_name, detected_text)
all_scores.append(score)
max_score = max(all_scores) if all_scores else 0
candidates = [idx for idx, score in enumerate(all_scores) if score == max_score]
final_idx = 0
method = "Visual"
# Case A: Strong Text Match
if max_score >= 2:
if len(candidates) == 1:
final_idx = candidates[0]
method = "Text Lock"
else:
# Hybrid Tie-Break
method = "Hybrid"
emb_query = model.encode(query_img, convert_to_tensor=True)
subset_emb = db_embeddings[candidates]
hits = util.semantic_search(emb_query, subset_emb, top_k=1)[0]
local_idx = hits[0]['corpus_id']
final_idx = candidates[local_idx]
# Case B: Weak Text Match
elif max_score == 1:
method = "Visual (Filtered)"
emb_query = model.encode(query_img, convert_to_tensor=True)
subset_emb = db_embeddings[candidates]
hits = util.semantic_search(emb_query, subset_emb, top_k=1)[0]
local_idx = hits[0]['corpus_id']
final_idx = candidates[local_idx]
# Case C: Visual Only
else:
method = "Visual Only"
emb_query = model.encode(query_img, convert_to_tensor=True)
hits = util.semantic_search(emb_query, db_embeddings, top_k=1)[0]
final_idx = hits[0]['corpus_id']
result_name = db_names[final_idx]
# Clean up name format
final_clean = result_name.lstrip(" -_").replace("_", " ").replace("-", " ")
final_clean = " ".join(final_clean.split())
return {
"name": final_clean,
"ocr_raw": detected_text,
"method": method
}