File size: 6,972 Bytes
7b9f40a
2f4af3f
 
 
 
 
 
 
7b9f40a
2f4af3f
 
 
 
 
 
7b9f40a
 
 
 
 
 
 
 
2f4af3f
 
7b9f40a
 
36331c6
 
 
7b9f40a
 
 
36331c6
7b9f40a
 
 
2f4af3f
7b9f40a
2f4af3f
7b9f40a
 
 
2f4af3f
7b9f40a
2f4af3f
 
7b9f40a
 
36331c6
 
 
 
7b9f40a
 
2f4af3f
7b9f40a
2f4af3f
7b9f40a
 
2f4af3f
7b9f40a
2f4af3f
 
 
 
 
 
7b9f40a
 
 
 
 
2f4af3f
 
7b9f40a
 
2f4af3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b9f40a
 
2f4af3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import time
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
from config import Config
from utils.gpu_diagnostics import log_model_device


class CLIPClassifier:
    """Zero-shot CLIP classifier for script types and segmented symbol crops.

    The CLIP model and processor are not loaded in ``__init__``; they are
    loaded on the first classification call (see ``_ensure_loaded``). The
    checkpoint name is read from ``Config.CLIP_MODEL`` when that attribute
    exists, otherwise ``_FALLBACK_MODEL`` is used.
    """

    # Checkpoint used when the config has no CLIP_MODEL attribute, and retried
    # when loading the configured checkpoint fails.
    _FALLBACK_MODEL = "openai/clip-vit-base-patch32"

    # (label, text prompt) pairs scored against the image in
    # classify_script_type; the label of the best-scoring prompt is returned.
    _SCRIPT_PROMPTS = (
        ("egyptian", "ancient Egyptian hieroglyphic writing with drawings of animals and humans"),
        ("greek", "ancient Greek alphabet script on papyrus or stone with polytonic symbols"),
        ("latin", "medieval Latin manuscript text written in ink on parchment"),
        ("cuneiform", "ancient Mesopotamian cuneiform tablet with wedge-shaped markings in clay"),
    )

    def __init__(self):
        """Record configuration and target device; defer the model load."""
        self.config = Config()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.processor = None
        self._loaded = False
        print("[INFO] CLIPClassifier created (lazy — model will load on first use)", flush=True)

    def _load_checkpoint(self, checkpoint, *, fallback=False):
        """Load one CLIP checkpoint and return ``(model, processor)``.

        The model is moved to ``self.device`` and switched to eval mode before
        it is returned. Nothing is stored on ``self`` here, so a failure at any
        step leaves the classifier's state untouched.

        Args:
            checkpoint: Model name passed to ``from_pretrained``.
            fallback: Selects the "Fallback" progress messages instead of the
                four-step messages used for the configured checkpoint.

        Raises:
            Exception: Whatever ``from_pretrained`` or the device move raises;
                the caller decides how to recover.
        """
        import os  # function-scope import, as in the original loader

        hf_token = os.getenv("HF_TOKEN")
        started = time.time()

        if fallback:
            print(f"[CLIP LAZY] Fallback 1/2 — Loading: {checkpoint}...", flush=True)
        else:
            print(f"[CLIP LAZY] Step 1/4 — Loading CLIPModel: {checkpoint}...", flush=True)
        model = CLIPModel.from_pretrained(checkpoint, token=hf_token)
        if not fallback:
            print(
                f"[CLIP LAZY] Step 2/4 — CLIPModel loaded in {time.time()-started:.1f}s. "
                "Loading CLIPProcessor...",
                flush=True,
            )

        processor_started = time.time()
        processor = CLIPProcessor.from_pretrained(checkpoint, token=hf_token)
        if fallback:
            print(f"[CLIP LAZY] Fallback 2/2 — Moving to {self.device}...", flush=True)
        else:
            print(
                f"[CLIP LAZY] Step 3/4 — CLIPProcessor loaded in "
                f"{time.time()-processor_started:.1f}s. Moving to {self.device}...",
                flush=True,
            )

        model.to(self.device)
        model.eval()
        if fallback:
            log_model_device("CLIP script classifier (fallback)", self.device)
            print(f"[CLIP LAZY] Fallback CLIP ready — total {time.time()-started:.1f}s", flush=True)
        else:
            log_model_device("CLIP script classifier", self.device)
            print(
                f"[CLIP LAZY] Step 4/4 — CLIP ready on {self.device} — total "
                f"{time.time()-started:.1f}s",
                flush=True,
            )
        return model, processor

    def _ensure_loaded(self):
        """Lazily load the CLIP model and processor on first use, with fallback.

        Tries the configured checkpoint first and ``_FALLBACK_MODEL`` second.
        ``self.model`` and ``self.processor`` are assigned together, and only
        after a checkpoint has loaded completely, so ``pipeline`` never exposes
        a half-initialised model. If both attempts fail the classifier stays
        unloaded and the next call tries again.
        """
        if self._loaded:
            return

        model_name = getattr(self.config, 'CLIP_MODEL', self._FALLBACK_MODEL)
        try:
            model, processor = self._load_checkpoint(model_name)
        except Exception as e:
            print(f"[WARN] Failed to load CLIP model '{model_name}': {e}", flush=True)
            try:
                model, processor = self._load_checkpoint(self._FALLBACK_MODEL, fallback=True)
            except Exception as fe:
                print(f"[ERROR] Failed to load fallback CLIP model: {fe}", flush=True)
                return

        # Publish both objects at once, only after a fully successful load.
        self.model = model
        self.processor = processor
        self._loaded = True

    @property
    def pipeline(self):
        """Property checked in app.py/test.py to ensure model is initialized.

        Returns the loaded model, or ``None`` while no model is available.
        """
        return self.model

    @property
    def is_loaded(self):
        """Check if model has been lazily loaded yet."""
        return self._loaded

    def classify_script_type(self, image):
        """Classify the script type of an image into one of four categories.

        Args:
            image: A PIL image, or a NumPy array that is converted with
                ``Image.fromarray``.

        Returns:
            ``(label, score)`` where ``label`` is one of ``"egyptian"``,
            ``"greek"``, ``"latin"`` or ``"cuneiform"`` and ``score`` is the
            softmax probability of that label over the four prompts. Returns
            ``("unknown", 0.0)`` when no model is available or when
            classification raises.
        """
        self._ensure_loaded()

        if self.pipeline is None:
            return "unknown", 0.0

        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            scripts = [label for label, _ in self._SCRIPT_PROMPTS]
            descriptions = [prompt for _, prompt in self._SCRIPT_PROMPTS]

            inputs = self.processor(
                text=descriptions,
                images=image,
                return_tensors="pt",
                padding=True
            ).to(self.device)

            with torch.inference_mode():
                outputs = self.model(**inputs)
                # One image in the batch, so row 0 holds its prompt scores.
                probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]

            best_idx = int(np.argmax(probs))
            score = float(probs[best_idx])
            script_label = scripts[best_idx]

            print(f"[INFO] CLIP script classification: {script_label} ({score:.3f})")
            return script_label, score

        except Exception as e:
            print(f"[ERROR] CLIP script classification failed: {e}")
            return "unknown", 0.0

    def classify_symbols(self, crops, candidate_labels):
        """Classify segmented symbol image crops against candidate labels.

        Each label is turned into a text prompt, embedded once, and compared
        by cosine similarity with the embedding of every crop.

        Args:
            crops: Sequence of PIL images or NumPy arrays (arrays are converted
                with ``Image.fromarray``).
            candidate_labels: Sequence of label strings; underscores are
                replaced by spaces when building the prompts.

        Returns:
            A list with one entry per crop: the best-matching candidate label.
            Returns ``[]`` when ``crops`` is empty, and a list of ``None`` (one
            per crop) when there are no candidate labels or no model is
            available. If inference raises, every crop is given the first
            candidate label.
        """
        # Answer trivial requests before triggering the lazy model load.
        if not crops or not candidate_labels:
            return [None] * len(crops) if crops else []

        self._ensure_loaded()

        if self.pipeline is None:
            return [None] * len(crops)

        try:
            print(f"[INFO] Batch classifying {len(crops)} crops using CLIP...")

            # Format candidate labels into descriptive prompts for better visual matching
            prompts = [
                f"an ancient Egyptian hieroglyph symbol of a {label.replace('_', ' ')}"
                for label in candidate_labels
            ]

            # Tokenize and embed the prompts once; they are shared by all crops.
            text_inputs = self.processor(
                text=prompts,
                return_tensors="pt",
                padding=True
            ).to(self.device)

            with torch.inference_mode():
                text_features = self.model.get_text_features(**text_inputs)
                # L2-normalise so the dot product below is a cosine similarity.
                text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

            results = []
            for crop in crops:
                if isinstance(crop, np.ndarray):
                    crop = Image.fromarray(crop)

                image_inputs = self.processor(images=crop, return_tensors="pt").to(self.device)

                with torch.inference_mode():
                    image_features = self.model.get_image_features(**image_inputs)
                    image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)

                    # Compute cosine similarities
                    similarities = (image_features @ text_features.T).squeeze(0)
                    best_idx = torch.argmax(similarities).item()

                results.append(candidate_labels[best_idx])

            return results

        except Exception as e:
            print(f"[ERROR] CLIP symbol classification failed: {e}")
            # NOTE(review): on failure every crop is labelled with the first
            # candidate rather than None; kept as-is because callers may rely
            # on it — confirm this is the intended best-effort behaviour.
            return [candidate_labels[0]] * len(crops)