# Source: decipherai-api/models/clip_classifier.py
# Last commit (Akshay30): "Fix Greek OCR and update Latin OCR model" (36331c6)
import time
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
from config import Config
from utils.gpu_diagnostics import log_model_device
class CLIPClassifier:
    """Zero-shot CLIP classifier for script types and symbol crops.

    Construction is cheap: it only reads the config and picks a device. The
    CLIP weights are loaded on the first classification call (see
    ``_ensure_loaded``). ``model`` and ``processor`` are assigned together,
    and only after a checkpoint has been loaded, moved to the device and put
    in eval mode, so ``pipeline`` never exposes a half-initialised model.
    """

    # Checkpoint tried when the configured one fails to load; also the
    # default when the config object has no ``CLIP_MODEL`` attribute.
    _FALLBACK_MODEL = "openai/clip-vit-base-patch32"

    # (label, text prompt) pairs scored by ``classify_script_type``. Kept as
    # pairs so a label can never drift out of sync with its prompt.
    _SCRIPT_PROMPTS = (
        ("egyptian", "ancient Egyptian hieroglyphic writing with drawings of animals and humans"),
        ("greek", "ancient Greek alphabet script on papyrus or stone with polytonic symbols"),
        ("latin", "medieval Latin manuscript text written in ink on parchment"),
        ("cuneiform", "ancient Mesopotamian cuneiform tablet with wedge-shaped markings in clay"),
    )

    def __init__(self):
        """Record config and target device; no model weights are loaded here."""
        self.config = Config()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.processor = None
        self._loaded = False
        print("[INFO] CLIPClassifier created (lazy — model will load on first use)", flush=True)

    def _prepare(self, model, label):
        """Move ``model`` to ``self.device``, switch it to eval mode and log the device."""
        model.to(self.device)
        model.eval()
        log_model_device(label, self.device)

    def _ensure_loaded(self):
        """Lazily load the CLIP model and processor on first use, with fallback.

        The configured checkpoint (``Config.CLIP_MODEL``) is tried first and
        the default OpenAI checkpoint second. Each attempt works on local
        variables and only publishes ``self.model`` / ``self.processor`` /
        ``self._loaded`` once every step succeeded, so a failed attempt leaves
        the instance exactly as it was. Errors are logged, never raised.
        Because ``_loaded`` stays ``False`` after a failure, the next call
        tries again.
        """
        if self._loaded:
            return

        # Function-scope import: ``os`` is not imported at module level.
        import os

        hf_token = os.getenv("HF_TOKEN")
        model_name = getattr(self.config, 'CLIP_MODEL', self._FALLBACK_MODEL)

        try:
            started = time.time()
            print(f"[CLIP LAZY] Step 1/4 — Loading CLIPModel: {model_name}...", flush=True)
            model = CLIPModel.from_pretrained(model_name, token=hf_token)
            print(
                f"[CLIP LAZY] Step 2/4 — CLIPModel loaded in {time.time() - started:.1f}s. "
                "Loading CLIPProcessor...",
                flush=True,
            )
            processor_started = time.time()
            processor = CLIPProcessor.from_pretrained(model_name, token=hf_token)
            print(
                f"[CLIP LAZY] Step 3/4 — CLIPProcessor loaded in "
                f"{time.time() - processor_started:.1f}s. Moving to {self.device}...",
                flush=True,
            )
            self._prepare(model, "CLIP script classifier")
            print(
                f"[CLIP LAZY] Step 4/4 — CLIP ready on {self.device} — "
                f"total {time.time() - started:.1f}s",
                flush=True,
            )
            # Publish everything at once, only after the whole pipeline is usable.
            self.model, self.processor, self._loaded = model, processor, True
            return
        except Exception as e:
            print(f"[WARN] Failed to load CLIP model '{model_name}': {e}", flush=True)

        fallback_name = self._FALLBACK_MODEL
        try:
            started = time.time()
            print(f"[CLIP LAZY] Fallback 1/2 — Loading: {fallback_name}...", flush=True)
            model = CLIPModel.from_pretrained(fallback_name, token=hf_token)
            processor = CLIPProcessor.from_pretrained(fallback_name, token=hf_token)
            print(f"[CLIP LAZY] Fallback 2/2 — Moving to {self.device}...", flush=True)
            self._prepare(model, "CLIP script classifier (fallback)")
            print(f"[CLIP LAZY] Fallback CLIP ready — total {time.time() - started:.1f}s", flush=True)
            self.model, self.processor, self._loaded = model, processor, True
        except Exception as fe:
            print(f"[ERROR] Failed to load fallback CLIP model: {fe}", flush=True)

    @property
    def pipeline(self):
        """Property checked in app.py/test.py to ensure model is initialized.

        Returns the loaded model, or ``None`` while nothing has been loaded.
        """
        return self.model

    @property
    def is_loaded(self):
        """Check if model has been lazily loaded yet."""
        return self._loaded

    @staticmethod
    def _as_pil(image):
        """Convert a NumPy array to a PIL image; pass anything else through."""
        if isinstance(image, np.ndarray):
            return Image.fromarray(image)
        return image

    def classify_script_type(self, image):
        """Classify the script type of an image into one of four categories.

        Args:
            image: A PIL image or a NumPy array (converted with
                ``Image.fromarray``).

        Returns:
            ``(label, score)`` where ``label`` is one of ``"egyptian"``,
            ``"greek"``, ``"latin"``, ``"cuneiform"`` and ``score`` is the
            softmax probability of that label over the four prompts. Returns
            ``("unknown", 0.0)`` when the model is unavailable or inference
            fails; the error is logged, not raised.
        """
        self._ensure_loaded()
        if self.pipeline is None:
            return "unknown", 0.0

        try:
            image = self._as_pil(image)
            scripts = [label for label, _ in self._SCRIPT_PROMPTS]
            descriptions = [prompt for _, prompt in self._SCRIPT_PROMPTS]

            inputs = self.processor(
                text=descriptions,
                images=image,
                return_tensors="pt",
                padding=True,
            ).to(self.device)

            with torch.inference_mode():
                outputs = self.model(**inputs)

            # Single image in the batch: row 0 holds one probability per prompt.
            probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]
            best_idx = int(np.argmax(probs))
            score = float(probs[best_idx])
            script_label = scripts[best_idx]

            print(f"[INFO] CLIP script classification: {script_label} ({score:.3f})")
            return script_label, score
        except Exception as e:
            print(f"[ERROR] CLIP script classification failed: {e}")
            return "unknown", 0.0

    def classify_symbols(self, crops, candidate_labels):
        """Classify segmented symbol image crops against candidate labels.

        Each label is turned into a hieroglyph prompt (underscores become
        spaces). The prompts are embedded once, and each crop is assigned the
        label whose text embedding has the highest cosine similarity with the
        crop's image embedding.

        Args:
            crops: Sequence of PIL images or NumPy arrays.
            candidate_labels: Sequence of label strings to choose from.

        Returns:
            One label per crop, in order. When the model is unavailable or
            ``candidate_labels`` is empty, a list of ``None`` of the same
            length as ``crops`` (``[]`` for empty ``crops``). If inference
            raises, the error is logged and every crop gets
            ``candidate_labels[0]``.
        """
        self._ensure_loaded()
        if self.pipeline is None or not crops or not candidate_labels:
            return [None] * len(crops) if crops else []

        try:
            print(f"[INFO] Batch classifying {len(crops)} crops using CLIP...")

            # Descriptive prompts match visual content better than bare labels.
            prompts = [
                f"an ancient Egyptian hieroglyph symbol of a {label.replace('_', ' ')}"
                for label in candidate_labels
            ]

            # The text side is identical for every crop, so embed it once.
            text_inputs = self.processor(
                text=prompts,
                return_tensors="pt",
                padding=True,
            ).to(self.device)

            results = []
            with torch.inference_mode():
                text_features = self.model.get_text_features(**text_inputs)
                text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

                for crop in crops:
                    crop = self._as_pil(crop)
                    image_inputs = self.processor(images=crop, return_tensors="pt").to(self.device)
                    image_features = self.model.get_image_features(**image_inputs)
                    image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)

                    # Both sides are unit-normalised, so the dot product is
                    # the cosine similarity.
                    similarities = (image_features @ text_features.T).squeeze(0)
                    best_idx = torch.argmax(similarities).item()
                    results.append(candidate_labels[best_idx])

            return results
        except Exception as e:
            print(f"[ERROR] CLIP symbol classification failed: {e}")
            # Best-effort fallback kept from the original behaviour: callers
            # get the first candidate for every crop rather than an exception.
            # candidate_labels is non-empty here (guarded above).
            return [candidate_labels[0]] * len(crops)