rththr committed on
Commit
5ed3d23
·
verified ·
1 Parent(s): 378fd2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -31
app.py CHANGED
@@ -1,44 +1,174 @@
 
 
 
1
  import easyocr
 
2
  import numpy as np
3
- from fastapi import FastAPI, File, UploadFile
4
- from fastapi.middleware.cors import CORSMiddleware
5
  from PIL import Image
6
- import io
 
 
7
 
8
  app = FastAPI()
9
 
10
- # Enable CORS so your HTML tool can access this API
11
- app.add_middleware(
12
- CORSMiddleware,
13
- allow_origins=["*"],
14
- allow_credentials=True,
15
- allow_methods=["*"],
16
- allow_headers=["*"],
17
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- print("Loading EasyOCR Model...")
20
- # Use CPU mode for free tier (gpu=False)
21
- reader = easyocr.Reader(['en'], gpu=False)
22
- print("Model Ready!")
 
 
 
 
 
 
 
23
 
24
  @app.get("/")
25
- def home():
26
- return {"status": "EasyOCR API is Running"}
27
 
28
  @app.post("/ocr")
29
- async def extract_text(image: UploadFile = File(...)):
30
- try:
31
- # 1. Read Image
32
- contents = await image.read()
33
- img = Image.open(io.BytesIO(contents)).convert("RGB")
34
- img_np = np.array(img)
35
-
36
- # 2. Run EasyOCR
37
- result = reader.readtext(img_np, detail=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # 3. Join Text
40
- text = " ".join(result)
41
- return {"text": text}
 
 
 
 
 
 
 
 
 
42
 
43
- except Exception as e:
44
- return {"error": str(e)}
 
 
 
 
1
+ import os
2
+ import glob
3
+ import torch
4
  import easyocr
5
+ import zipfile # <--- Added for unzipping
6
  import numpy as np
 
 
7
  from PIL import Image
8
+ from io import BytesIO
9
+ from fastapi import FastAPI, File, UploadFile
10
+ from sentence_transformers import SentenceTransformer, util
11
 
12
  app = FastAPI()
13
 
14
+ # --- CONFIG ---
15
+ DATABASE_ZIP = "database.zip"
16
+ DATABASE_PATH = "database"
17
+ CACHE_FILE = "db_cache.pt"
18
+
19
+ # --- GLOBALS ---
20
+ model = None
21
+ reader = None
22
+ db_embeddings = None
23
+ db_names = []
24
+
25
+ def load_resources():
26
+ global model, reader, db_embeddings, db_names
27
+
28
+ # 1. AUTO-UNZIP LOGIC
29
+ # Checks if zip exists and if we haven't unzipped it yet (or just to be safe)
30
+ if os.path.exists(DATABASE_ZIP):
31
+ print(f"📦 Found {DATABASE_ZIP}, checking contents...")
32
+ # We check if the folder already exists to save time, or force unzip if needed.
33
+ # Here we force unzip to ensure we have the latest data from your upload.
34
+ try:
35
+ with zipfile.ZipFile(DATABASE_ZIP, 'r') as zip_ref:
36
+ zip_ref.extractall(".")
37
+ print("✅ Unzipped successfully!")
38
+ except Exception as e:
39
+ print(f"❌ Error unzipping: {e}")
40
+
41
+ print("Loading AI Models...")
42
+ model = SentenceTransformer('clip-ViT-B-32')
43
+
44
+ print("Loading OCR...")
45
+ # Force CPU if no GPU available in Space
46
+ reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
47
+
48
+ # --- LOAD DATABASE ---
49
+ print("Indexing Database...")
50
+
51
+ # (Optional) If you want to force a re-index every time you upload a new zip,
52
+ # you can remove the cache file check. For now, we keep it.
53
+ if os.path.exists(CACHE_FILE) and not os.path.exists(DATABASE_ZIP):
54
+ # Only load cache if we didn't just upload a new zip
55
+ print("Loading from cache...")
56
+ cache_data = torch.load(CACHE_FILE)
57
+ db_embeddings = cache_data['embeddings']
58
+ db_names = cache_data['names']
59
+ else:
60
+ print("Building fresh index from images...")
61
+ temp_emb = []
62
+ temp_names = []
63
+
64
+ if not os.path.exists(DATABASE_PATH):
65
+ os.makedirs(DATABASE_PATH)
66
+
67
+ files = glob.glob(os.path.join(DATABASE_PATH, "*"))
68
+ print(f"Found {len(files)} images in folder.")
69
+
70
+ for f in files:
71
+ try:
72
+ img = Image.open(f).convert("RGB")
73
+ emb = model.encode(img, convert_to_tensor=True)
74
+ temp_emb.append(emb)
75
+ # Clean filename for the ID
76
+ name = os.path.basename(f).rsplit('.', 1)[0]
77
+ temp_names.append(name)
78
+ except Exception as e:
79
+ print(f"Skip {f}: {e}")
80
+
81
+ if temp_emb:
82
+ db_embeddings = torch.stack(temp_emb)
83
+ db_names = temp_names
84
+ torch.save({'embeddings': db_embeddings, 'names': db_names}, CACHE_FILE)
85
+
86
+ print(f"Ready! Loaded {len(db_names)} reference items.")
87
 
88
+ # Initialize on startup
89
+ load_resources()
90
+
91
+ def calculate_text_match(db_filename, ocr_text):
92
+ # Normalize DB Name
93
+ db_clean = db_filename.lower().replace("_", " ").replace("-", " ").replace(".", " ")
94
+ db_words = set(db_clean.split())
95
+ # Normalize OCR Text
96
+ ocr_clean = ocr_text.lower().replace("_", " ").replace("-", " ").replace(".", " ")
97
+ ocr_words = set(ocr_clean.split())
98
+ return len(db_words.intersection(ocr_words))
99
 
100
  @app.get("/")
101
+ def health_check():
102
+ return {"status": "running", "database_size": len(db_names) if db_names else 0}
103
 
104
  @app.post("/ocr")
105
+ async def identify_skin(image: UploadFile = File(...)):
106
+ # 1. Read Image
107
+ contents = await image.read()
108
+ query_img = Image.open(BytesIO(contents)).convert("RGB")
109
+
110
+ # 2. OCR (Bottom 30% Logic)
111
+ w, h = query_img.size
112
+ # Crop bottom 30% for text detection
113
+ bottom_crop = query_img.crop((0, int(h*0.70), w, h))
114
+ bottom_np = np.array(bottom_crop)
115
+
116
+ ocr_result = reader.readtext(bottom_np, detail=0)
117
+ detected_text = " ".join(ocr_result).lower()
118
+
119
+ # 3. MATCHING LOGIC
120
+ if not db_names:
121
+ return {"name": "Database Empty", "ocr_raw": detected_text, "method": "Error"}
122
+
123
+ all_scores = []
124
+ for db_name in db_names:
125
+ score = calculate_text_match(db_name, detected_text)
126
+ all_scores.append(score)
127
+
128
+ max_score = max(all_scores) if all_scores else 0
129
+ candidates = [idx for idx, score in enumerate(all_scores) if score == max_score]
130
+
131
+ final_idx = 0
132
+ method = "Visual"
133
+
134
+ # Case A: Strong Text Match
135
+ if max_score >= 2:
136
+ if len(candidates) == 1:
137
+ final_idx = candidates[0]
138
+ method = "Text Lock"
139
+ else:
140
+ # Hybrid Tie-Break
141
+ method = "Hybrid"
142
+ emb_query = model.encode(query_img, convert_to_tensor=True)
143
+ subset_emb = db_embeddings[candidates]
144
+ hits = util.semantic_search(emb_query, subset_emb, top_k=1)[0]
145
+ local_idx = hits[0]['corpus_id']
146
+ final_idx = candidates[local_idx]
147
+
148
+ # Case B: Weak Text Match
149
+ elif max_score == 1:
150
+ method = "Visual (Filtered)"
151
+ emb_query = model.encode(query_img, convert_to_tensor=True)
152
+ subset_emb = db_embeddings[candidates]
153
+ hits = util.semantic_search(emb_query, subset_emb, top_k=1)[0]
154
+ local_idx = hits[0]['corpus_id']
155
+ final_idx = candidates[local_idx]
156
 
157
+ # Case C: Visual Only
158
+ else:
159
+ method = "Visual Only"
160
+ emb_query = model.encode(query_img, convert_to_tensor=True)
161
+ hits = util.semantic_search(emb_query, db_embeddings, top_k=1)[0]
162
+ final_idx = hits[0]['corpus_id']
163
+
164
+ result_name = db_names[final_idx]
165
+
166
+ # Clean up name format
167
+ final_clean = result_name.lstrip(" -_").replace("_", " ").replace("-", " ")
168
+ final_clean = " ".join(final_clean.split())
169
 
170
+ return {
171
+ "name": final_clean,
172
+ "ocr_raw": detected_text,
173
+ "method": method
174
+ }