Spaces:
Sleeping
Sleeping
| import time | |
| import torch | |
| from transformers import CLIPProcessor, CLIPModel | |
| from PIL import Image | |
| import numpy as np | |
| from config import Config | |
| from utils.gpu_diagnostics import log_model_device | |
class CLIPClassifier:
    """Zero-shot CLIP classifier for script types and segmented symbol crops.

    The CLIP model and processor are not loaded in ``__init__``; they are
    loaded on the first classification call by ``_ensure_loaded``.
    """

    def __init__(self):
        self.config = Config()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Both stay None until _ensure_loaded() succeeds.
        self.model = None
        self.processor = None
        self._loaded = False
        print("[INFO] CLIPClassifier created (lazy β model will load on first use)", flush=True)

    def _ensure_loaded(self):
        """Lazily load CLIP model and processor on first use, with fallback.

        Tries the checkpoint named by ``config.CLIP_MODEL`` (defaulting to
        ``openai/clip-vit-base-patch32``) and, if that raises, tries
        ``openai/clip-vit-base-patch32``. If both attempts fail, the error is
        printed, ``model``/``processor`` are left as None and ``_loaded``
        stays False, so the next call attempts the load again.
        """
        if self._loaded:
            return

        # Function-scope import: os is not among the file-level imports.
        import os
        hf_token = os.getenv("HF_TOKEN")
        model_name = getattr(self.config, 'CLIP_MODEL', 'openai/clip-vit-base-patch32')

        try:
            _t0 = time.time()
            print(f"[CLIP LAZY] Step 1/4 β Loading CLIPModel: {model_name}...", flush=True)
            model = CLIPModel.from_pretrained(model_name, token=hf_token)
            print(f"[CLIP LAZY] Step 2/4 β CLIPModel loaded in {time.time()-_t0:.1f}s. Loading CLIPProcessor...", flush=True)
            _t1 = time.time()
            processor = CLIPProcessor.from_pretrained(model_name, token=hf_token)
            print(f"[CLIP LAZY] Step 3/4 β CLIPProcessor loaded in {time.time()-_t1:.1f}s. Moving to {self.device}...", flush=True)
            model.to(self.device)
            model.eval()
            log_model_device("CLIP script classifier", self.device)
            print(f"[CLIP LAZY] Step 4/4 β CLIP ready on {self.device} β total {time.time()-_t0:.1f}s", flush=True)
            # Publish model and processor together so a partial failure can
            # never leave a model without its processor on the instance.
            self.model = model
            self.processor = processor
            self._loaded = True
            return
        except Exception as e:
            print(f"[WARN] Failed to load CLIP model '{model_name}': {e}", flush=True)

        fallback_name = "openai/clip-vit-base-patch32"
        try:
            _t0 = time.time()
            print(f"[CLIP LAZY] Fallback 1/2 β Loading: {fallback_name}...", flush=True)
            model = CLIPModel.from_pretrained(fallback_name, token=hf_token)
            processor = CLIPProcessor.from_pretrained(fallback_name, token=hf_token)
            print(f"[CLIP LAZY] Fallback 2/2 β Moving to {self.device}...", flush=True)
            model.to(self.device)
            model.eval()
            log_model_device("CLIP script classifier (fallback)", self.device)
            print(f"[CLIP LAZY] Fallback CLIP ready β total {time.time()-_t0:.1f}s", flush=True)
            self.model = model
            self.processor = processor
            self._loaded = True
        except Exception as fe:
            print(f"[ERROR] Failed to load fallback CLIP model: {fe}", flush=True)
            # Keep the instance in a clean "not loaded" state.
            self.model = None
            self.processor = None

    def _is_ready(self):
        """Return True when both the model and the processor are available."""
        return self.model is not None and self.processor is not None

    def pipeline(self):
        """Return the loaded CLIP model, or None if it has not been loaded.

        NOTE: this is a regular method, not a property, so it must be called
        (``obj.pipeline()``); the bare attribute ``obj.pipeline`` is a bound
        method and is always truthy.
        """
        return self.model

    def is_loaded(self):
        """Check if model has been lazily loaded yet."""
        return self._loaded

    def classify_script_type(self, image):
        """Classify the script type of an image into one of four categories.

        Args:
            image: a PIL image, or a numpy array (converted with
                ``Image.fromarray``).

        Returns:
            ``(label, score)`` where ``label`` is one of ``"egyptian"``,
            ``"greek"``, ``"latin"``, ``"cuneiform"`` and ``score`` is the
            softmax probability of that label. Returns ``("unknown", 0.0)``
            when the model is unavailable or classification raises.
        """
        self._ensure_loaded()
        # Check the attributes directly: `self.pipeline` is a bound method and
        # therefore always truthy, which would make this guard a no-op.
        if not self._is_ready():
            return "unknown", 0.0
        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            # Labels and the text prompt describing each, index-aligned.
            scripts = ["egyptian", "greek", "latin", "cuneiform"]
            descriptions = [
                "ancient Egyptian hieroglyphic writing with drawings of animals and humans",
                "ancient Greek alphabet script on papyrus or stone with polytonic symbols",
                "medieval Latin manuscript text written in ink on parchment",
                "ancient Mesopotamian cuneiform tablet with wedge-shaped markings in clay"
            ]

            inputs = self.processor(
                text=descriptions,
                images=image,
                return_tensors="pt",
                padding=True
            ).to(self.device)

            with torch.inference_mode():
                outputs = self.model(**inputs)
                logits_per_image = outputs.logits_per_image
                probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]

            best_idx = int(np.argmax(probs))
            score = float(probs[best_idx])
            script_label = scripts[best_idx]
            print(f"[INFO] CLIP script classification: {script_label} ({score:.3f})")
            return script_label, score
        except Exception as e:
            print(f"[ERROR] CLIP script classification failed: {e}")
            return "unknown", 0.0

    def classify_symbols(self, crops, candidate_labels):
        """Classify segmented symbol image crops against candidate labels.

        Args:
            crops: sequence of PIL images or numpy arrays, one per symbol.
            candidate_labels: sequence of label strings; underscores are
                replaced by spaces when building the text prompts.

        Returns:
            A list with one entry per crop: the candidate label whose prompt
            has the highest cosine similarity to the crop. Returns ``[]`` for
            empty ``crops``, and a list of ``None`` when the model is
            unavailable or ``candidate_labels`` is empty. If classification
            raises, every crop is assigned ``candidate_labels[0]`` as a
            best-effort fallback.
        """
        self._ensure_loaded()
        # Check the attributes directly: `self.pipeline` is a bound method and
        # therefore always truthy, which would make this guard a no-op.
        if not self._is_ready() or not crops or not candidate_labels:
            return [None] * len(crops) if crops else []
        try:
            print(f"[INFO] Batch classifying {len(crops)} crops using CLIP...")
            # Format candidate labels into descriptive prompts for better visual matching
            prompts = [f"an ancient Egyptian hieroglyph symbol of a {label.replace('_', ' ')}" for label in candidate_labels]

            # Tokenize and embed the prompts once; they are reused for every crop.
            text_inputs = self.processor(
                text=prompts,
                return_tensors="pt",
                padding=True
            ).to(self.device)

            results = []
            with torch.inference_mode():
                text_features = self.model.get_text_features(**text_inputs)
                text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

                for crop in crops:
                    if isinstance(crop, np.ndarray):
                        crop = Image.fromarray(crop)
                    image_inputs = self.processor(images=crop, return_tensors="pt").to(self.device)
                    image_features = self.model.get_image_features(**image_inputs)
                    image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
                    # Both sides are L2-normalised, so the dot product is the cosine similarity.
                    similarities = (image_features @ text_features.T).squeeze(0)
                    best_idx = torch.argmax(similarities).item()
                    results.append(candidate_labels[best_idx])
            return results
        except Exception as e:
            print(f"[ERROR] CLIP symbol classification failed: {e}")
            return [candidate_labels[0]] * len(crops) if candidate_labels else [None] * len(crops)