|
|
import os |
|
|
import glob |
|
|
import torch |
|
|
import easyocr |
|
|
import zipfile |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from io import BytesIO |
|
|
from fastapi import FastAPI, File, UploadFile |
|
|
from sentence_transformers import SentenceTransformer, util |
|
|
|
|
|
# FastAPI application instance; endpoints are registered on it via the
# decorators further down in this file.
app = FastAPI()


# Optional zip of reference images; its presence triggers extraction and
# forces a fresh re-index (see load_resources, which skips the cache when
# this file exists).
DATABASE_ZIP = "database.zip"

# Folder holding the extracted reference images that get indexed.
DATABASE_PATH = "database"

# torch-serialized cache of {'embeddings', 'names'} reused across restarts.
CACHE_FILE = "db_cache.pt"


# Lazily populated by load_resources() at import time:
model = None          # SentenceTransformer CLIP model ('clip-ViT-B-32')
reader = None         # easyocr.Reader for English text
db_embeddings = None  # stacked tensor of reference-image embeddings
db_names = []         # parallel list of names derived from image filenames
|
|
|
|
|
def _extract_database_zip():
    """Extract DATABASE_ZIP into the working directory if it is present.

    Extraction failures are logged and swallowed (best-effort): the service
    can still run from a previously extracted folder or the cache.
    """
    if not os.path.exists(DATABASE_ZIP):
        return
    print(f"📦 Found {DATABASE_ZIP}, checking contents...")
    try:
        with zipfile.ZipFile(DATABASE_ZIP, 'r') as zip_ref:
            zip_ref.extractall(".")
        print("✅ Unzipped successfully!")
    except Exception as e:
        print(f"❌ Error unzipping: {e}")


def _build_index():
    """Embed every image in DATABASE_PATH and persist the index to CACHE_FILE.

    Populates the module-level ``db_embeddings`` / ``db_names``. Unreadable
    files are skipped with a log line rather than aborting the whole build.
    """
    global db_embeddings, db_names
    print("Building fresh index from images...")
    temp_emb = []
    temp_names = []

    # Make sure the folder exists so glob below is well-defined even on a
    # first run with no extracted images yet.
    os.makedirs(DATABASE_PATH, exist_ok=True)

    files = glob.glob(os.path.join(DATABASE_PATH, "*"))
    print(f"Found {len(files)} images in folder.")

    for f in files:
        try:
            img = Image.open(f).convert("RGB")
            emb = model.encode(img, convert_to_tensor=True)
            temp_emb.append(emb)
            # The filename without its extension doubles as the item's name.
            name = os.path.basename(f).rsplit('.', 1)[0]
            temp_names.append(name)
        except Exception as e:
            print(f"Skip {f}: {e}")

    if temp_emb:
        db_embeddings = torch.stack(temp_emb)
        db_names = temp_names
        torch.save({'embeddings': db_embeddings, 'names': db_names}, CACHE_FILE)


def load_resources():
    """Load the CLIP model, the OCR reader, and the reference-image index.

    Populates the module-level globals ``model``, ``reader``,
    ``db_embeddings`` and ``db_names``. The embedding index is served from
    CACHE_FILE when possible; the cache is bypassed whenever DATABASE_ZIP is
    present, since a zip signals fresh content that must be re-indexed.
    """
    global model, reader, db_embeddings, db_names

    _extract_database_zip()

    print("Loading AI Models...")
    model = SentenceTransformer('clip-ViT-B-32')

    print("Loading OCR...")
    reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

    print("Indexing Database...")
    if os.path.exists(CACHE_FILE) and not os.path.exists(DATABASE_ZIP):
        print("Loading from cache...")
        try:
            cache_data = torch.load(CACHE_FILE)
            db_embeddings = cache_data['embeddings']
            db_names = cache_data['names']
        except Exception as e:
            # A corrupt or incompatible cache must not kill startup —
            # rebuild the index from the images instead.
            print(f"Cache load failed ({e}), rebuilding index...")
            _build_index()
    else:
        _build_index()

    print(f"Ready! Loaded {len(db_names)} reference items.")
|
|
|
|
|
|
|
|
# Eagerly initialize models and the index at import time so the first
# request is not slow. NOTE(review): this blocks worker startup; a FastAPI
# startup event would be the more conventional hook — confirm before moving.
load_resources()
|
|
|
|
|
# Maps the filename separators '_', '-', '.' to spaces in one pass.
_SEPARATOR_TABLE = str.maketrans("_-.", "   ")


def _tokenize(text):
    """Lower-case *text*, treat '_'/'-'/'.' as spaces, return its word set."""
    return set(text.lower().translate(_SEPARATOR_TABLE).split())


def calculate_text_match(db_filename, ocr_text):
    """Count the words shared between a database filename and OCR text.

    Both strings are lower-cased and split on whitespace after mapping the
    separators '_', '-' and '.' to spaces; the score is the size of the
    intersection of the resulting word sets.
    """
    return len(_tokenize(db_filename) & _tokenize(ocr_text))
|
|
|
|
|
@app.get("/") |
|
|
def health_check(): |
|
|
return {"status": "running", "database_size": len(db_names) if db_names else 0} |
|
|
|
|
|
@app.post("/ocr") |
|
|
async def identify_skin(image: UploadFile = File(...)): |
|
|
|
|
|
contents = await image.read() |
|
|
query_img = Image.open(BytesIO(contents)).convert("RGB") |
|
|
|
|
|
|
|
|
w, h = query_img.size |
|
|
|
|
|
bottom_crop = query_img.crop((0, int(h*0.70), w, h)) |
|
|
bottom_np = np.array(bottom_crop) |
|
|
|
|
|
ocr_result = reader.readtext(bottom_np, detail=0) |
|
|
detected_text = " ".join(ocr_result).lower() |
|
|
|
|
|
|
|
|
if not db_names: |
|
|
return {"name": "Database Empty", "ocr_raw": detected_text, "method": "Error"} |
|
|
|
|
|
all_scores = [] |
|
|
for db_name in db_names: |
|
|
score = calculate_text_match(db_name, detected_text) |
|
|
all_scores.append(score) |
|
|
|
|
|
max_score = max(all_scores) if all_scores else 0 |
|
|
candidates = [idx for idx, score in enumerate(all_scores) if score == max_score] |
|
|
|
|
|
final_idx = 0 |
|
|
method = "Visual" |
|
|
|
|
|
|
|
|
if max_score >= 2: |
|
|
if len(candidates) == 1: |
|
|
final_idx = candidates[0] |
|
|
method = "Text Lock" |
|
|
else: |
|
|
|
|
|
method = "Hybrid" |
|
|
emb_query = model.encode(query_img, convert_to_tensor=True) |
|
|
subset_emb = db_embeddings[candidates] |
|
|
hits = util.semantic_search(emb_query, subset_emb, top_k=1)[0] |
|
|
local_idx = hits[0]['corpus_id'] |
|
|
final_idx = candidates[local_idx] |
|
|
|
|
|
|
|
|
elif max_score == 1: |
|
|
method = "Visual (Filtered)" |
|
|
emb_query = model.encode(query_img, convert_to_tensor=True) |
|
|
subset_emb = db_embeddings[candidates] |
|
|
hits = util.semantic_search(emb_query, subset_emb, top_k=1)[0] |
|
|
local_idx = hits[0]['corpus_id'] |
|
|
final_idx = candidates[local_idx] |
|
|
|
|
|
|
|
|
else: |
|
|
method = "Visual Only" |
|
|
emb_query = model.encode(query_img, convert_to_tensor=True) |
|
|
hits = util.semantic_search(emb_query, db_embeddings, top_k=1)[0] |
|
|
final_idx = hits[0]['corpus_id'] |
|
|
|
|
|
result_name = db_names[final_idx] |
|
|
|
|
|
|
|
|
final_clean = result_name.lstrip(" -_").replace("_", " ").replace("-", " ") |
|
|
final_clean = " ".join(final_clean.split()) |
|
|
|
|
|
return { |
|
|
"name": final_clean, |
|
|
"ocr_raw": detected_text, |
|
|
"method": method |
|
|
} |