import os
import json
import traceback
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, AutoModel
import gradio as gr

# --- Device Setup ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# For 8-bit models, the vision dtype is handled by bitsandbytes
# We still need HEAD_DTYPE for our classifier head
HEAD_DTYPE = torch.float32

# --- DINOv3 Specific Constants ---
DINOV3_PATCH_SIZE = 16
MAX_DINOV3_RESOLUTION = 4096
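
# Example of the rounding applied before DINOv3 inference (see
# predict_anatomy_v3 below): a 1000x750 image is resized to 1008x752, since
# both sides are rounded up to the nearest multiple of DINOV3_PATCH_SIZE (16).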

print(f"Using device: {DEVICE}")
print(f"Head model dtype: {HEAD_DTYPE}")


# --- Model Definitions (copied unchanged from hybrid_model.py) ---
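# RMSNorm, used throughout the head, rescales by the root-mean-square of the
# features instead of mean-centering like LayerNorm:
#   y = x / sqrt(mean(x^2) + eps) * weight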
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps
    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class SwiGLUFFN(nn.Module):
    def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None, act_layer: nn.Module = nn.SiLU, dropout: float = 0.):
        super().__init__()
        out_features = out_features or in_features
        # Default to the ~8/3 expansion conventional for SwiGLU FFNs, rounded up to an even width.
        hidden_features = hidden_features or int(in_features * 8 / 3)
        hidden_features = (hidden_features + 1) // 2 * 2
        self.w12 = nn.Linear(in_features, hidden_features * 2, bias=False)
        self.act = act_layer()
        self.dropout1 = nn.Dropout(dropout)
        self.w3 = nn.Linear(hidden_features, out_features, bias=False)
        self.dropout2 = nn.Dropout(dropout)
    def forward(self, x):
        gate_val, up_val = self.w12(x).chunk(2, dim=-1)
        x = self.dropout1(self.act(gate_val) * up_val)
        x = self.dropout2(self.w3(x))
        return x

class ResBlockRMS(nn.Module):
    def __init__(self, ch: int, dropout: float = 0.0, rms_norm_eps: float = 1e-6):
        super().__init__()
        self.norm = RMSNorm(ch, eps=rms_norm_eps)
        self.ffn = SwiGLUFFN(in_features=ch, dropout=dropout)
    def forward(self, x):
        return x + self.ffn(self.norm(x))

class HybridHeadModel(nn.Module):
    def __init__(self, features: int, hidden_dim: int = 1280, num_classes: int = 2, use_attention: bool = True,
                 num_attn_heads: int = 16, attn_dropout: float = 0.1, num_res_blocks: int = 3,
                 dropout_rate: float = 0.1, rms_norm_eps: float = 1e-6, output_mode: str = 'linear'):
        super().__init__()
        self.features = features
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.use_attention = use_attention
        self.output_mode = output_mode.lower()
        self.attention = None
        self.norm_attn = None
        if self.use_attention:
            actual_num_heads = num_attn_heads
            if features % num_attn_heads != 0:
                possible_heads = [h for h in [1, 2, 4, 8, 16, 32] if features % h == 0]  # expanded list of candidate head counts
                if not possible_heads:
                    actual_num_heads = 1
                else:
                    actual_num_heads = min(possible_heads, key=lambda h: abs(h - num_attn_heads))
                if actual_num_heads != num_attn_heads:
                    print(f"HybridHead Warning: Adjusting heads {num_attn_heads}->{actual_num_heads} for features={features}")
            self.attention = nn.MultiheadAttention(features, actual_num_heads, dropout=attn_dropout, batch_first=True, bias=True)
            self.norm_attn = RMSNorm(features, eps=rms_norm_eps)
        mlp_layers = [nn.Linear(features, hidden_dim), RMSNorm(hidden_dim, eps=rms_norm_eps)]
        for _ in range(num_res_blocks):
            mlp_layers.append(ResBlockRMS(hidden_dim, dropout=dropout_rate, rms_norm_eps=rms_norm_eps))
        mlp_layers.append(RMSNorm(hidden_dim, eps=rms_norm_eps))
        down_proj_hidden = hidden_dim // 2
        mlp_layers.append(SwiGLUFFN(hidden_dim, hidden_features=down_proj_hidden, out_features=down_proj_hidden, dropout=dropout_rate))
        mlp_layers.append(RMSNorm(down_proj_hidden, eps=rms_norm_eps))
        mlp_layers.append(nn.Linear(down_proj_hidden, num_classes))
        self.mlp_head = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor):
        if self.use_attention and self.attention is not None:
            x_seq = x.unsqueeze(1)
            attn_output, _ = self.attention(x_seq, x_seq, x_seq)
            x = self.norm_attn(x + attn_output.squeeze(1))
        logits = self.mlp_head(x.to(HEAD_DTYPE))
        output_mode = self.output_mode
        if output_mode == 'linear':
            output = logits
        elif output_mode == 'sigmoid':
            output = torch.sigmoid(logits)
        elif output_mode == 'softmax':
            output = F.softmax(logits, dim=-1)
        elif output_mode == 'tanh_scaled':
            output = (torch.tanh(logits) + 1.0) / 2.0
        else:
            raise RuntimeError(f"Invalid output_mode '{output_mode}'.")
        if self.num_classes == 1 and output.ndim == 2 and output.shape[1] == 1:
            output = output.squeeze(-1)
        return output
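
# Minimal usage sketch for HybridHeadModel (illustrative dimensions; the real
# ones come from the downloaded config):
#   head = HybridHeadModel(features=4096, hidden_dim=1280, num_classes=2)
#   logits = head(torch.randn(1, 4096))  # -> shape (1, 2) when output_mode='linear'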

# --- Model Catalog ---
MODEL_CATALOG = {
    "AnatomyFlaws-v15.5 (DINOv3 7b bf16)": {
        "repo_id": "Enferlain/lumi-classifier",
        "config_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl.config.json",
        "head_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl_s3K_best_val.safetensors",
        # Explicitly pin the vision model repo ID to avoid resolution errors.
        # Quantized alternatives:
        # "vision_model_repo_id": "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-8bit",  # bnb 8-bit
        # "vision_model_repo_id": "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-int4",  # int4
        "vision_model_repo_id": "PIA-SPACE-LAB/dinov3-vit7b16-pretrain-lvd1689m",
    },
    "AnatomyFlaws-v14.7 (SigLIP naflex)": {
        "repo_id": "Enferlain/lumi-classifier",
        "config_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670.config.json",
        "head_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670_s2K.safetensors",
        # The base SigLIP model is not custom, so we use its official ID
        "vision_model_repo_id": "google/siglip2-so400m-patch16-naflex"
    },
}
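
# To register another classifier, add an entry with the same keys (hypothetical
# filenames shown):
#   "MyClassifier-v1 (SigLIP)": {
#       "repo_id": "user/my-classifier",
#       "config_filename": "MyClassifier-v1.config.json",
#       "head_filename": "MyClassifier-v1.safetensors",
#       "vision_model_repo_id": "google/siglip2-so400m-patch16-naflex",
#   },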

# --- Model Manager Class ---
class ModelManager:
    def __init__(self, catalog: Dict[str, Dict[str, str]]):
        self.catalog = catalog
        self.current_model_name: Optional[str] = None
        self.vision_model: Optional[nn.Module] = None
        self.hf_processor: Any = None
        self.head_model: Optional[HybridHeadModel] = None
        self.labels: Optional[Dict[int, str]] = None
        self.config: Optional[Dict[str, Any]] = None

    def load_model(self, model_name: str):
        if model_name == self.current_model_name:
            return
        if model_name not in self.catalog:
            raise ValueError(f"Model '{model_name}' not found.")

        print(f"Switching to model: {model_name}...")

        model_info = self.catalog[model_name]
        repo_id = model_info["repo_id"]
        config_filename = model_info["config_filename"]
        head_filename = model_info["head_filename"]
        vision_model_repo_id = model_info["vision_model_repo_id"]

        try:
            config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)

            print(f"Loading vision model: {vision_model_repo_id}")
            self.hf_processor = AutoProcessor.from_pretrained(vision_model_repo_id, trust_remote_code=True)

            # --- CPU-compatible loading: prefer BF16, fall back to FP32 ---
            if DEVICE == "cpu":
                # For CPU, load unquantized model with BF16 (original format)
                print("Loading unquantized model for CPU...")
                try:
                    self.vision_model = AutoModel.from_pretrained(
                        vision_model_repo_id,
                        torch_dtype=torch.bfloat16,  # Keep original BF16 format
                        device_map={"": "cpu"},      # Force CPU device mapping
                        trust_remote_code=True
                    ).eval()
                    print("Successfully loaded model in BF16 format.")
                except Exception as bf16_error:
                    print(f"BF16 loading failed: {bf16_error}")
                    print("Falling back to FP32...")
                    self.vision_model = AutoModel.from_pretrained(
                        vision_model_repo_id,
                        torch_dtype=torch.float32,  # Fallback to FP32
                        device_map={"": "cpu"},
                        trust_remote_code=True
                    ).eval()
                    print("Successfully loaded model in FP32 format.")
            else:
                # GPU: load in FP16 (DEVICE can only be "cuda" here)
                self.vision_model = AutoModel.from_pretrained(
                    vision_model_repo_id,
                    torch_dtype=torch.float16,
                    trust_remote_code=True
                ).to(DEVICE).eval()

            # Load the classifier head
            head_model_path = hf_hub_download(repo_id=repo_id, filename=head_filename)
            print(f"Loading head model: {head_filename}")
            state_dict = load_file(head_model_path, device='cpu')
            head_params = self.config.get("predictor_params", self.config)
            self.head_model = HybridHeadModel(
                features=head_params.get("features"), 
                hidden_dim=head_params.get("hidden_dim"),
                num_classes=self.config.get("num_classes"), 
                use_attention=head_params.get("use_attention"),
                num_attn_heads=head_params.get("num_attn_heads"), 
                attn_dropout=head_params.get("attn_dropout"),
                num_res_blocks=head_params.get("num_res_blocks"), 
                dropout_rate=head_params.get("dropout_rate"),
                output_mode=head_params.get("output_mode", "linear")
            )
            self.head_model.load_state_dict(state_dict, strict=True)
            self.head_model.to(DEVICE).eval()

            raw_labels = self.config.get("labels", {'0': 'Bad', '1': 'Good'})
            self.labels = {int(k): (v['name'] if isinstance(v, dict) else v) for k, v in raw_labels.items()}
            self.current_model_name = model_name
            print(f"Successfully loaded '{model_name}'.")

        except Exception as e:
            self.current_model_name = None
            raise RuntimeError(f"Failed to load model '{model_name}': {e}\n{traceback.format_exc()}")

# --- Global Model Manager Instance ---
model_manager = ModelManager(MODEL_CATALOG)
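
# Usage sketch: load_model caches the current model, so repeated calls with the
# same name return immediately (illustrative):
#   model_manager.load_model("AnatomyFlaws-v14.7 (SigLIP naflex)")
#   model_manager.load_model("AnatomyFlaws-v14.7 (SigLIP naflex)")  # no-op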

# --- Prediction Function (v3 from before) ---
def predict_anatomy_v3(image: Image.Image, model_name: str):
    if image is None:
        return {"Error": 1.0, "Info": 0.0}  # Return numeric values
    try:
        model_manager.load_model(model_name)
        pil_image = image.convert("RGB")
        emb = None

        with torch.no_grad():
            base_model_type = model_manager.config.get("base_vision_model", "")
            if "dinov3" in base_model_type.lower():
                current_w, current_h = pil_image.size
                img_to_process = pil_image
                if max(current_w, current_h) > MAX_DINOV3_RESOLUTION:
                    scale = MAX_DINOV3_RESOLUTION / max(current_w, current_h)
                    current_w, current_h = int(current_w * scale), int(current_h * scale)
                    img_to_process = pil_image.resize((current_w, current_h), Image.Resampling.LANCZOS)
                new_w = ((current_w + DINOV3_PATCH_SIZE - 1) // DINOV3_PATCH_SIZE) * DINOV3_PATCH_SIZE
                new_h = ((current_h + DINOV3_PATCH_SIZE - 1) // DINOV3_PATCH_SIZE) * DINOV3_PATCH_SIZE
                if new_w != current_w or new_h != current_h:
                    img_to_process = img_to_process.resize((new_w, new_h), Image.Resampling.LANCZOS)
                inputs = model_manager.hf_processor(images=[img_to_process], return_tensors="pt")
                # For 8-bit, send inputs to the same device as the model
                pixel_values = inputs.pixel_values.to(model_manager.vision_model.device)
                outputs = model_manager.vision_model(pixel_values=pixel_values)
                last_hidden_state = outputs.last_hidden_state
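                # DINOv3 token layout is [CLS, register tokens..., patch tokens];
                # drop the CLS and register tokens, then mean-pool the patches.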
                nreg = getattr(model_manager.vision_model.config, 'num_register_tokens', 0)
                patch_embeddings = last_hidden_state[:, 1 + nreg:]
                emb = torch.mean(patch_embeddings, dim=1)
            elif "siglip" in base_model_type.lower():
                inputs = model_manager.hf_processor(images=[pil_image], return_tensors="pt")
                pixel_values = inputs.get("pixel_values").to(device=DEVICE, dtype=torch.float16)
                if "naflex" in base_model_type.lower():
                    attention_mask = inputs.get("pixel_attention_mask").to(device=DEVICE)
                    spatial_shapes = inputs.get("spatial_shapes")
                    model_call_kwargs = {"pixel_values": pixel_values, "attention_mask": attention_mask,
                                         "spatial_shapes": torch.tensor(spatial_shapes, dtype=torch.long).to(DEVICE)}
                    vision_model_component = getattr(model_manager.vision_model, 'vision_model', model_manager.vision_model)
                    emb = vision_model_component(**model_call_kwargs).pooler_output
                else:
                    emb = model_manager.vision_model.get_image_features(pixel_values=pixel_values)
            else:
                raise ValueError(f"Unknown base model type for embedding: {base_model_type}")
            if emb is None:
                raise ValueError("Failed to get embedding.")
            norm = torch.linalg.norm(emb.float(), dim=-1, keepdim=True).clamp(min=1e-8)
            emb_normalized = emb / norm.to(emb.dtype)
        with torch.no_grad():
            prediction = model_manager.head_model(emb_normalized.to(DEVICE, dtype=HEAD_DTYPE))
        output_probs = {}
        if model_manager.head_model.num_classes == 2:
            probs = F.softmax(prediction.squeeze().float(), dim=-1)
            output_probs[model_manager.labels[0]] = probs[0].item()
            output_probs[model_manager.labels[1]] = probs[1].item()
        else:
            prob_good = torch.sigmoid(prediction.squeeze()).item()
            output_probs[model_manager.labels[0]] = 1.0 - prob_good
            output_probs[model_manager.labels[1]] = prob_good
        return output_probs
    except Exception as e:
        print(f"Error during prediction: {e}\n{traceback.format_exc()}")
        # Return properly formatted error for Gradio Label
        error_msg = (str(e)[:50] + "...") if len(str(e)) > 50 else str(e)
        return {
            f"Error: {error_msg}": 1.0,
            "Please check logs": 0.0
        }
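
# Standalone sketch, bypassing the Gradio UI (hypothetical image path):
#   probs = predict_anatomy_v3(Image.open("examples/sample.png"),
#                              list(MODEL_CATALOG.keys())[0])
#   print(probs)  # e.g. {'Bad': 0.12, 'Good': 0.88}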

# --- Gradio Interface ---
DESCRIPTION = """
## Lumi's Anatomy Flaw Classifier Demo ✨
Select a model from the dropdown, then upload an image to classify its anatomy/structural correctness.
Will be slow since it runs on cpu, ~2minutes on dinov3.
"""
EXAMPLE_DIR = "examples"

default_model = list(MODEL_CATALOG.keys())[0]

# 1. Find the paths to our example images
example_paths = []
if os.path.isdir(EXAMPLE_DIR):
    example_paths = [os.path.join(EXAMPLE_DIR, fname) for fname in sorted(os.listdir(EXAMPLE_DIR)) if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]

# 2. Create the nested list Gradio needs: [[image, model_name], [image, model_name], ...]
examples_nested = []
if example_paths:
    examples_nested = [[path, default_model] for path in example_paths]

# 3. Create the interface, passing the correctly formatted list
interface = gr.Interface(
    fn=predict_anatomy_v3,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=list(MODEL_CATALOG.keys()), value=default_model, label="Classifier Model")
    ],
    outputs=gr.Label(label="Class Probabilities", num_top_classes=2),
    title="Lumi's Anatomy Classifier",
    description=DESCRIPTION,
    examples=examples_nested if examples_nested else None, # Pass the new nested list
    allow_flagging="never",
    cache_examples=True
)

if __name__ == "__main__":
    try:
        print("Pre-loading default model...")
        model_manager.load_model(default_model)
    except Exception as e:
        print(f"WARNING: Could not pre-load default model. Error: {e}")
        
    interface.launch()