Spaces:
Sleeping
Sleeping
File size: 6,972 Bytes
7b9f40a 2f4af3f 7b9f40a 2f4af3f 7b9f40a 2f4af3f 7b9f40a 36331c6 7b9f40a 36331c6 7b9f40a 2f4af3f 7b9f40a 2f4af3f 7b9f40a 2f4af3f 7b9f40a 2f4af3f 7b9f40a 36331c6 7b9f40a 2f4af3f 7b9f40a 2f4af3f 7b9f40a 2f4af3f 7b9f40a 2f4af3f 7b9f40a 2f4af3f 7b9f40a 2f4af3f 7b9f40a 2f4af3f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | import time
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
from config import Config
from utils.gpu_diagnostics import log_model_device
class CLIPClassifier:
    """Zero-shot CLIP classifier for script types and symbol crops.

    The CLIP model and processor are loaded lazily on the first
    classification call, not in ``__init__``. Loading first tries the
    checkpoint named by ``Config.CLIP_MODEL`` (when that attribute exists)
    and then falls back to ``_FALLBACK_MODEL``.
    """

    # Checkpoint used as the default and as the fallback when the configured
    # checkpoint cannot be loaded.
    _FALLBACK_MODEL = "openai/clip-vit-base-patch32"

    # (label, prompt) pairs for the four supported script classes.
    _SCRIPT_PROMPTS = (
        ("egyptian", "ancient Egyptian hieroglyphic writing with drawings of animals and humans"),
        ("greek", "ancient Greek alphabet script on papyrus or stone with polytonic symbols"),
        ("latin", "medieval Latin manuscript text written in ink on parchment"),
        ("cuneiform", "ancient Mesopotamian cuneiform tablet with wedge-shaped markings in clay"),
    )

    def __init__(self):
        """Record configuration and target device; defer model loading."""
        self.config = Config()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.processor = None
        self._loaded = False
        print("[INFO] CLIPClassifier created (lazy — model will load on first use)", flush=True)

    def _ensure_loaded(self):
        """Lazily load the CLIP model and processor on first use, with fallback.

        The model and processor are built in local variables and only
        published to ``self`` once the whole sequence (download, device move,
        ``eval()``) has succeeded. A failure part-way through therefore never
        leaves the instance with a model but no processor.

        If both the configured and the fallback checkpoint fail, the error is
        logged, ``self.model`` stays ``None`` and ``_loaded`` stays ``False``,
        so the next call will try again.
        """
        if self._loaded:
            return

        import os  # local import, as in the original: the file header does not import os

        hf_token = os.getenv("HF_TOKEN")
        model_name = getattr(self.config, 'CLIP_MODEL', self._FALLBACK_MODEL)

        # --- Attempt 1: configured checkpoint -------------------------------
        try:
            started = time.time()
            print(f"[CLIP LAZY] Step 1/4 — Loading CLIPModel: {model_name}...", flush=True)
            model = CLIPModel.from_pretrained(model_name, token=hf_token)
            print(f"[CLIP LAZY] Step 2/4 — CLIPModel loaded in {time.time()-started:.1f}s. Loading CLIPProcessor...", flush=True)
            processor_started = time.time()
            processor = CLIPProcessor.from_pretrained(model_name, token=hf_token)
            print(f"[CLIP LAZY] Step 3/4 — CLIPProcessor loaded in {time.time()-processor_started:.1f}s. Moving to {self.device}...", flush=True)
            model.to(self.device)
            model.eval()
            log_model_device("CLIP script classifier", self.device)
            print(f"[CLIP LAZY] Step 4/4 — CLIP ready on {self.device} — total {time.time()-started:.1f}s", flush=True)
        except Exception as e:
            print(f"[WARN] Failed to load CLIP model '{model_name}': {e}", flush=True)
        else:
            # Publish only after every step succeeded.
            self.model, self.processor = model, processor
            self._loaded = True
            return

        # --- Attempt 2: fallback checkpoint ---------------------------------
        fallback_name = self._FALLBACK_MODEL
        try:
            started = time.time()
            print(f"[CLIP LAZY] Fallback 1/2 — Loading: {fallback_name}...", flush=True)
            model = CLIPModel.from_pretrained(fallback_name, token=hf_token)
            processor = CLIPProcessor.from_pretrained(fallback_name, token=hf_token)
            print(f"[CLIP LAZY] Fallback 2/2 — Moving to {self.device}...", flush=True)
            model.to(self.device)
            model.eval()
            log_model_device("CLIP script classifier (fallback)", self.device)
            print(f"[CLIP LAZY] Fallback CLIP ready — total {time.time()-started:.1f}s", flush=True)
        except Exception as fe:
            print(f"[ERROR] Failed to load fallback CLIP model: {fe}", flush=True)
        else:
            self.model, self.processor = model, processor
            self._loaded = True

    @property
    def pipeline(self):
        """Return the loaded model, or ``None`` if no model is loaded.

        Per the original docstring, app.py/test.py check this property to
        see whether the model is initialized.
        """
        return self.model

    @property
    def is_loaded(self):
        """Return ``True`` once the lazy load has completed successfully."""
        return self._loaded

    def classify_script_type(self, image):
        """Classify the script type of an image into one of four categories.

        Args:
            image: A ``numpy.ndarray`` (converted with ``Image.fromarray``) or
                any other image object accepted by the CLIP processor.

        Returns:
            A ``(label, score)`` tuple. ``label`` is one of ``"egyptian"``,
            ``"greek"``, ``"latin"`` or ``"cuneiform"`` and ``score`` is the
            softmax probability of that label. Returns ``("unknown", 0.0)``
            when no model is available or classification raises.
        """
        self._ensure_loaded()
        if self.pipeline is None:
            return "unknown", 0.0

        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            labels = [label for label, _ in self._SCRIPT_PROMPTS]
            prompts = [prompt for _, prompt in self._SCRIPT_PROMPTS]

            inputs = self.processor(
                text=prompts,
                images=image,
                return_tensors="pt",
                padding=True,
            ).to(self.device)

            with torch.inference_mode():
                outputs = self.model(**inputs)
                # Row 0 is the single input image; columns follow prompt order.
                probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]

            best_idx = int(np.argmax(probs))
            score = float(probs[best_idx])
            script_label = labels[best_idx]
            print(f"[INFO] CLIP script classification: {script_label} ({score:.3f})")
            return script_label, score
        except Exception as e:
            # Best-effort boundary: callers get a sentinel instead of a crash.
            print(f"[ERROR] CLIP script classification failed: {e}")
            return "unknown", 0.0

    def classify_symbols(self, crops, candidate_labels):
        """Classify segmented symbol crops against candidate labels.

        Each label is expanded into a descriptive prompt; every crop is then
        assigned the label whose prompt embedding has the highest cosine
        similarity with the crop's image embedding.

        Args:
            crops: Sequence of crops, each a ``numpy.ndarray`` or an image
                object accepted by the CLIP processor.
            candidate_labels: Sequence of label strings; underscores are
                replaced by spaces when building prompts.

        Returns:
            A list with one entry per crop:
            * ``[]`` when ``crops`` is empty;
            * ``[None] * len(crops)`` when no model is available or
              ``candidate_labels`` is empty;
            * the best-matching label per crop on success;
            * ``[candidate_labels[0]] * len(crops)`` if classification raises.
              NOTE(review): this fallback is kept from the original, but it is
              indistinguishable from a real prediction; confirm callers rely
              on it before changing it.
        """
        self._ensure_loaded()
        if self.pipeline is None or not crops or not candidate_labels:
            return [None] * len(crops) if crops else []

        try:
            print(f"[INFO] Batch classifying {len(crops)} crops using CLIP...")

            # Descriptive prompts match visual content better than bare labels.
            prompts = [
                f"an ancient Egyptian hieroglyph symbol of a {label.replace('_', ' ')}"
                for label in candidate_labels
            ]

            # Tokenize and embed the prompts once; they are reused for every crop.
            text_inputs = self.processor(
                text=prompts,
                return_tensors="pt",
                padding=True
            ).to(self.device)

            results = []
            with torch.inference_mode():
                text_features = self.model.get_text_features(**text_inputs)
                text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

                for crop in crops:
                    if isinstance(crop, np.ndarray):
                        crop = Image.fromarray(crop)

                    image_inputs = self.processor(images=crop, return_tensors="pt").to(self.device)
                    image_features = self.model.get_image_features(**image_inputs)
                    image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)

                    # Both sides are L2-normalised, so the dot product is cosine similarity.
                    similarities = (image_features @ text_features.T).squeeze(0)
                    results.append(candidate_labels[torch.argmax(similarities).item()])

            return results
        except Exception as e:
            print(f"[ERROR] CLIP symbol classification failed: {e}")
            # candidate_labels is non-empty here (guarded above).
            return [candidate_labels[0]] * len(crops)
|