File size: 5,910 Bytes
5ed3d23
 
 
19c05cc
5ed3d23
19c05cc
 
5ed3d23
 
 
19c05cc
 
 
5ed3d23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19c05cc
5ed3d23
 
 
 
 
 
 
 
 
 
 
19c05cc
 
5ed3d23
 
19c05cc
 
5ed3d23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19c05cc
5ed3d23
 
 
 
 
 
 
 
 
 
 
 
19c05cc
5ed3d23
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import glob
import torch
import easyocr
import zipfile  # <--- Added for unzipping
import numpy as np
from PIL import Image
from io import BytesIO
from fastapi import FastAPI, File, UploadFile
from sentence_transformers import SentenceTransformer, util

app = FastAPI()

# --- CONFIG ---
# Zip uploaded alongside the app; auto-extracted at startup if present.
DATABASE_ZIP = "database.zip"
# Folder of reference images; one image per identifiable item.
DATABASE_PATH = "database" 
# Serialized {embeddings, names} tensor cache to skip re-encoding on boot.
CACHE_FILE = "db_cache.pt"

# --- GLOBALS ---
# Populated by load_resources() at import time.
model = None          # SentenceTransformer CLIP model
reader = None         # easyocr.Reader instance
db_embeddings = None  # torch.Tensor of stacked image embeddings (or None if DB empty)
db_names = []         # filename-derived names, parallel to db_embeddings rows

def _extract_database_zip():
    """Unzip DATABASE_ZIP into the working directory if it is present.

    Best-effort: a corrupt or unreadable zip is logged and ignored so a
    bad upload cannot stop the server from booting.
    """
    if not os.path.exists(DATABASE_ZIP):
        return
    print(f"📦 Found {DATABASE_ZIP}, checking contents...")
    # Force unzip every boot so a freshly uploaded zip always wins.
    try:
        with zipfile.ZipFile(DATABASE_ZIP, 'r') as zip_ref:
            zip_ref.extractall(".") 
        print("✅ Unzipped successfully!")
    except Exception as e:
        print(f"❌ Error unzipping: {e}")

def _build_index():
    """Encode every image in DATABASE_PATH and persist the index to CACHE_FILE.

    Writes the module globals `db_embeddings` and `db_names`. Files that
    PIL cannot open are skipped with a log line.
    """
    global db_embeddings, db_names
    print("Building fresh index from images...")
    temp_emb = []
    temp_names = []

    if not os.path.exists(DATABASE_PATH):
        os.makedirs(DATABASE_PATH)

    # Only regular files: the zip may contain subdirectories, which would
    # otherwise reach Image.open() and be noisily skipped.
    files = [f for f in glob.glob(os.path.join(DATABASE_PATH, "*"))
             if os.path.isfile(f)]
    print(f"Found {len(files)} images in folder.")

    for f in files:
        try:
            img = Image.open(f).convert("RGB")
            emb = model.encode(img, convert_to_tensor=True)
            temp_emb.append(emb)
            # Filename (minus extension) doubles as the item's display name.
            name = os.path.basename(f).rsplit('.', 1)[0]
            temp_names.append(name)
        except Exception as e:
            print(f"Skip {f}: {e}")

    if temp_emb:
        db_embeddings = torch.stack(temp_emb)
        db_names = temp_names
        torch.save({'embeddings': db_embeddings, 'names': db_names}, CACHE_FILE)

def load_resources():
    """Load the CLIP model and OCR reader, then build or load the image index.

    Populates the module globals `model`, `reader`, `db_embeddings` and
    `db_names`. Intended to be called exactly once at startup.
    """
    global model, reader, db_embeddings, db_names
    
    # 1. AUTO-UNZIP LOGIC — extract any uploaded database archive first.
    _extract_database_zip()

    print("Loading AI Models...")
    model = SentenceTransformer('clip-ViT-B-32')
    
    print("Loading OCR...")
    # Force CPU if no GPU available in Space
    reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

    # --- LOAD DATABASE ---
    print("Indexing Database...")
    
    # A present zip invalidates the cache: we just (re-)extracted fresh data,
    # so only trust CACHE_FILE when no zip was uploaded.
    if os.path.exists(CACHE_FILE) and not os.path.exists(DATABASE_ZIP):
        print("Loading from cache...")
        cache_data = torch.load(CACHE_FILE)
        db_embeddings = cache_data['embeddings']
        db_names = cache_data['names']
    else:
        _build_index()
    
    print(f"Ready! Loaded {len(db_names)} reference items.")

# Initialize on startup
# NOTE(review): this runs at import time, so importing the module blocks on
# model download + indexing; consider a FastAPI startup hook if that becomes
# a problem — confirm deployment expectations before changing.
load_resources()

def calculate_text_match(db_filename, ocr_text):
    """Count the distinct words shared by a DB filename and OCR'd text.

    Underscores, hyphens and dots are treated as word separators on both
    sides, and the comparison is case-insensitive.
    """
    separators = str.maketrans("_-.", "   ")
    db_tokens = set(db_filename.lower().translate(separators).split())
    ocr_tokens = set(ocr_text.lower().translate(separators).split())
    return len(db_tokens & ocr_tokens)

@app.get("/")
def health_check():
    return {"status": "running", "database_size": len(db_names) if db_names else 0}

@app.post("/ocr")
async def identify_skin(image: UploadFile = File(...)):
    # 1. Read Image
    contents = await image.read()
    query_img = Image.open(BytesIO(contents)).convert("RGB")
    
    # 2. OCR (Bottom 30% Logic)
    w, h = query_img.size
    # Crop bottom 30% for text detection
    bottom_crop = query_img.crop((0, int(h*0.70), w, h))
    bottom_np = np.array(bottom_crop)
    
    ocr_result = reader.readtext(bottom_np, detail=0)
    detected_text = " ".join(ocr_result).lower()
    
    # 3. MATCHING LOGIC
    if not db_names:
        return {"name": "Database Empty", "ocr_raw": detected_text, "method": "Error"}

    all_scores = []
    for db_name in db_names:
        score = calculate_text_match(db_name, detected_text)
        all_scores.append(score)
    
    max_score = max(all_scores) if all_scores else 0
    candidates = [idx for idx, score in enumerate(all_scores) if score == max_score]
    
    final_idx = 0
    method = "Visual"
    
    # Case A: Strong Text Match
    if max_score >= 2:
        if len(candidates) == 1:
            final_idx = candidates[0]
            method = "Text Lock"
        else:
            # Hybrid Tie-Break
            method = "Hybrid"
            emb_query = model.encode(query_img, convert_to_tensor=True)
            subset_emb = db_embeddings[candidates]
            hits = util.semantic_search(emb_query, subset_emb, top_k=1)[0]
            local_idx = hits[0]['corpus_id']
            final_idx = candidates[local_idx]
            
    # Case B: Weak Text Match
    elif max_score == 1:
        method = "Visual (Filtered)"
        emb_query = model.encode(query_img, convert_to_tensor=True)
        subset_emb = db_embeddings[candidates]
        hits = util.semantic_search(emb_query, subset_emb, top_k=1)[0]
        local_idx = hits[0]['corpus_id']
        final_idx = candidates[local_idx]
        
    # Case C: Visual Only
    else:
        method = "Visual Only"
        emb_query = model.encode(query_img, convert_to_tensor=True)
        hits = util.semantic_search(emb_query, db_embeddings, top_k=1)[0]
        final_idx = hits[0]['corpus_id']

    result_name = db_names[final_idx]
    
    # Clean up name format
    final_clean = result_name.lstrip(" -_").replace("_", " ").replace("-", " ")
    final_clean = " ".join(final_clean.split())

    return {
        "name": final_clean,
        "ocr_raw": detected_text,
        "method": method
    }