import os
import time

import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from config import Config
from utils.gpu_diagnostics import log_model_device

# Checkpoint used when Config has no CLIP_MODEL, and as the fallback when the
# configured checkpoint fails to load.
DEFAULT_CLIP_MODEL = "openai/clip-vit-base-patch32"


class CLIPClassifier:
    """Zero-shot CLIP classifier for script type and segmented symbol crops.

    The CLIP model/processor are loaded lazily on first use (see
    ``_ensure_loaded``), so constructing the classifier is cheap.
    """

    # (label, text prompt) pairs for script classification. Kept as pairs so a
    # label can never get out of sync with its prompt; order also fixes the
    # argmax tie-break order.
    SCRIPT_PROMPTS = (
        ("egyptian", "ancient Egyptian hieroglyphic writing with drawings of animals and humans"),
        ("greek", "ancient Greek alphabet script on papyrus or stone with polytonic symbols"),
        ("latin", "medieval Latin manuscript text written in ink on parchment"),
        ("cuneiform", "ancient Mesopotamian cuneiform tablet with wedge-shaped markings in clay"),
    )

    def __init__(self):
        self.config = Config()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Only assigned by _activate(), after BOTH components loaded successfully.
        self.model = None
        self.processor = None
        self._loaded = False
        print("[INFO] CLIPClassifier created (lazy — model will load on first use)", flush=True)

    def _activate(self, model, processor, device_label):
        """Move *model* to ``self.device``, set eval mode, then publish it.

        The instance attributes are assigned only at the very end, so any
        failure here (e.g. while moving to the device) leaves the classifier
        in its previous, consistent unloaded state.
        """
        model.to(self.device)
        model.eval()
        log_model_device(device_label, self.device)
        self.model = model
        self.processor = processor
        self._loaded = True

    def _ensure_loaded(self):
        """Lazily load CLIP model and processor on first use, with fallback.

        Tries ``Config.CLIP_MODEL`` (or ``DEFAULT_CLIP_MODEL`` when unset)
        first, then ``DEFAULT_CLIP_MODEL``. If both attempts fail, the error is
        printed, the classifier stays unloaded (``pipeline`` is ``None``), and
        the next call will try again.
        """
        if self._loaded:
            return

        model_name = getattr(self.config, "CLIP_MODEL", DEFAULT_CLIP_MODEL)
        hf_token = os.getenv("HF_TOKEN")

        try:
            t0 = time.time()
            print(f"[CLIP LAZY] Step 1/4 — Loading CLIPModel: {model_name}...", flush=True)
            # Load into locals: nothing is published until both parts are ready.
            model = CLIPModel.from_pretrained(model_name, token=hf_token)
            print(
                f"[CLIP LAZY] Step 2/4 — CLIPModel loaded in {time.time()-t0:.1f}s. Loading CLIPProcessor...",
                flush=True,
            )
            t1 = time.time()
            processor = CLIPProcessor.from_pretrained(model_name, token=hf_token)
            print(
                f"[CLIP LAZY] Step 3/4 — CLIPProcessor loaded in {time.time()-t1:.1f}s. Moving to {self.device}...",
                flush=True,
            )
            self._activate(model, processor, "CLIP script classifier")
            print(
                f"[CLIP LAZY] Step 4/4 — CLIP ready on {self.device} — total {time.time()-t0:.1f}s",
                flush=True,
            )
            return
        except Exception as e:
            print(f"[WARN] Failed to load CLIP model '{model_name}': {e}", flush=True)

        fallback_name = DEFAULT_CLIP_MODEL
        try:
            t0 = time.time()
            print(f"[CLIP LAZY] Fallback 1/2 — Loading: {fallback_name}...", flush=True)
            model = CLIPModel.from_pretrained(fallback_name, token=hf_token)
            processor = CLIPProcessor.from_pretrained(fallback_name, token=hf_token)
            print(f"[CLIP LAZY] Fallback 2/2 — Moving to {self.device}...", flush=True)
            self._activate(model, processor, "CLIP script classifier (fallback)")
            print(f"[CLIP LAZY] Fallback CLIP ready — total {time.time()-t0:.1f}s", flush=True)
        except Exception as fe:
            print(f"[ERROR] Failed to load fallback CLIP model: {fe}", flush=True)

    @property
    def pipeline(self):
        """Property checked in app.py/test.py to ensure model is initialized.

        Returns the loaded CLIP model, or ``None`` when it is not (yet)
        available. It is non-None only after a complete, successful load.
        """
        return self.model

    @property
    def is_loaded(self):
        """Check if model has been lazily loaded yet."""
        return self._loaded

    def classify_script_type(self, image):
        """Classify the script of *image* into one of the four supported categories.

        Args:
            image: A PIL image, or a NumPy array convertible via
                ``PIL.Image.fromarray``.

        Returns:
            ``(label, probability)``, where label is one of "egyptian",
            "greek", "latin" or "cuneiform" and probability is the softmax
            score of that label. Returns ``("unknown", 0.0)`` if the model is
            unavailable or inference raises.
        """
        self._ensure_loaded()
        if self.pipeline is None:
            return "unknown", 0.0

        scripts = [label for label, _ in self.SCRIPT_PROMPTS]
        descriptions = [prompt for _, prompt in self.SCRIPT_PROMPTS]

        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            inputs = self.processor(
                text=descriptions, images=image, return_tensors="pt", padding=True
            ).to(self.device)

            with torch.inference_mode():
                logits_per_image = self.model(**inputs).logits_per_image
                probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]

            best_idx = int(np.argmax(probs))
            score = float(probs[best_idx])
            script_label = scripts[best_idx]
            print(f"[INFO] CLIP script classification: {script_label} ({score:.3f})", flush=True)
            return script_label, score
        except Exception as e:
            print(f"[ERROR] CLIP script classification failed: {e}", flush=True)
            return "unknown", 0.0

    def classify_symbols(self, crops, candidate_labels, batch_size=32):
        """Classify segmented symbol image crops against candidate labels.

        Each crop is assigned the candidate label whose text prompt has the
        highest cosine similarity with the crop's CLIP image embedding.

        Args:
            crops: Sequence of PIL images or NumPy arrays (converted with
                ``PIL.Image.fromarray``).
            candidate_labels: Sequence of label strings; underscores are
                replaced by spaces when building the text prompts.
            batch_size: Number of crops embedded per forward pass (values
                below 1 are treated as 1).

        Returns:
            A list with one label per crop, in input order.
            - If the model is unavailable or ``candidate_labels`` is empty,
              the list is ``[None] * len(crops)``.
            - If ``crops`` is empty, the result is ``[]``.
            - If inference raises, every crop gets ``candidate_labels[0]``.
              NOTE(review): that fallback is indistinguishable from a real
              prediction; kept for caller compatibility.
        """
        self._ensure_loaded()
        if self.pipeline is None or not crops or not candidate_labels:
            return [None] * len(crops) if crops else []

        batch_size = max(1, int(batch_size))

        try:
            print(f"[INFO] Batch classifying {len(crops)} crops using CLIP...", flush=True)
            # Descriptive prompts match visual content better than bare labels.
            prompts = [
                f"an ancient Egyptian hieroglyph symbol of a {label.replace('_', ' ')}"
                for label in candidate_labels
            ]
            text_inputs = self.processor(text=prompts, return_tensors="pt", padding=True).to(self.device)
            images = [Image.fromarray(c) if isinstance(c, np.ndarray) else c for c in crops]

            results = []
            with torch.inference_mode():
                # Text embeddings are shared by all crops: compute them once.
                text_features = self.model.get_text_features(**text_inputs)
                text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

                for start in range(0, len(images), batch_size):
                    chunk = images[start:start + batch_size]
                    image_inputs = self.processor(images=chunk, return_tensors="pt").to(self.device)
                    image_features = self.model.get_image_features(**image_inputs)
                    image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
                    # Both sides are L2-normalised, so the dot product is cosine similarity.
                    best = (image_features @ text_features.T).argmax(dim=-1).tolist()
                    results.extend(candidate_labels[i] for i in best)
            return results
        except Exception as e:
            print(f"[ERROR] CLIP symbol classification failed: {e}", flush=True)
            # candidate_labels is non-empty here (guarded above).
            return [candidate_labels[0]] * len(crops)