aneeshm44 committed on
Commit 74f22cb · verified · 1 Parent(s): d8a9fe7

Create app.py

Files changed (1): app.py +467 -0
app.py ADDED
@@ -0,0 +1,467 @@
import os
import tempfile
import timm
import einops
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
from typing import Union

# Vision Model
class TimmCNNModel(nn.Module):
    def __init__(self, num_classes: int = 8, model_name: str = "efficientnet_b0"):
        super().__init__()

        # num_classes=0 removes timm's classification head; the backbone then
        # returns pooled features from forward(), while backbone.forward_features()
        # still yields the spatial feature map used by the projector.
        self.backbone = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,
        )

        self.feature_dim = self.backbone.num_features

        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes)
        )

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.forward_features(x)
        logits = self.classifier(features)
        return logits

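# Shape note (an assumption, not stated in the original file): with
# efficientnet_b0's overall stride of 32, a 256x256 input produces a
# (B, 1280, 8, 8) map from backbone.forward_features(), i.e. the 64 spatial
# tokens and cnn_dim=1280 that Projector_4to3d expects below.
#   model = TimmCNNModel()
#   fmap = model.backbone.forward_features(torch.randn(1, 3, 256, 256))
#   # fmap.shape -> torch.Size([1, 1280, 8, 8])
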
# Projector Model
class Projector_4to3d(nn.Module):
    def __init__(self, cnn_dim: int = 1280, llm_dim: int = 2048, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.cnn_dim = cnn_dim
        self.llm_dim = llm_dim

        # Spatial positional embeddings for 8x8 grid
        self.spatial_pos_embed = nn.Parameter(torch.randn(64, cnn_dim))

        # Multi-scale feature processing
        self.spatial_conv = nn.Conv2d(cnn_dim, cnn_dim // 2, 1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Enhanced projection layers
        self.input_proj = nn.Sequential(
            nn.Linear(cnn_dim, llm_dim),
            nn.LayerNorm(llm_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Multi-head self-attention for spatial reasoning
        self.spatial_attention = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Cross-attention for text-image alignment
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self.norm1 = nn.LayerNorm(llm_dim)
        self.norm2 = nn.LayerNorm(llm_dim)

        # Enhanced FFN
        self.ffn = nn.Sequential(
            nn.Linear(llm_dim, llm_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(llm_dim * 4, llm_dim),
            nn.Dropout(dropout)
        )

        self.norm3 = nn.LayerNorm(llm_dim)

        # Token compression layer: 32 learned queries attention-pool the 64 tokens
        self.compress_tokens = nn.Parameter(torch.randn(32, llm_dim))
        self.token_compression = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, cnn_features: torch.Tensor, text_embeddings: torch.Tensor = None) -> torch.Tensor:
        batch_size = cnn_features.shape[0]

        # Multi-scale processing (note: these outputs are computed but not
        # currently fed into the projection path below)
        spatial_features = self.spatial_conv(cnn_features)
        global_context = self.global_pool(cnn_features).flatten(1)

        # Flatten spatial features and add positional encoding
        x = einops.rearrange(cnn_features, "b c h w -> b (h w) c")
        pos_embeddings = self.spatial_pos_embed.unsqueeze(0).expand(batch_size, -1, -1)
        x = x + pos_embeddings

        # Project to LLM dimension
        x = self.input_proj(x)

        # Self-attention for spatial reasoning
        attended_x, spatial_attn_weights = self.spatial_attention(x, x, x)
        x = self.norm1(x + attended_x)

        # Cross-attention with text (if available)
        if text_embeddings is not None:
            text_embeddings_float = text_embeddings.float()
            cross_attended, cross_attn_weights = self.cross_attention(x, text_embeddings_float, text_embeddings_float)
            x = self.norm2(x + cross_attended)

        # FFN
        ffn_out = self.ffn(x)
        x = self.norm3(x + ffn_out)

        # Token compression: 64 spatial tokens -> 32 query tokens
        compress_queries = self.compress_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        compressed_x, _ = self.token_compression(compress_queries, x, x)

        return compressed_x

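# Flow sketch (shapes inferred from the defaults above, not from the original
# comments): (B, 1280, 8, 8) -> 64 tokens -> project to 2048 -> self-/cross-
# attention + FFN -> attention-pool down to 32 tokens.
#   proj = Projector_4to3d()
#   out = proj(torch.randn(2, 1280, 8, 8))
#   # out.shape -> torch.Size([2, 32, 2048])
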
# Main VLM Model
class Model(nn.Module):
    def __init__(self, image_model, language_model, projector, tokenizer, prompt="Describe this image:"):
        super().__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.projector = projector
        self.tokenizer = tokenizer
        self.eos_token = tokenizer.eos_token
        self.prompt = prompt

        device = next(self.language_model.parameters()).device

        self.image_model.to(device)
        self.projector.to(device)

        # Create prompt embeddings once and cache them as a non-trainable buffer
        prompt_tokens = tokenizer(text=prompt, return_tensors="pt").input_ids.to(device)
        prompt_embeddings = language_model.get_input_embeddings()(prompt_tokens).detach()
        self.register_buffer('prompt_embeddings', prompt_embeddings)

    @property
    def device(self):
        return next(self.parameters()).device

    def generate(self, patches: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        device = self.device
        patches = patches.to(device)

        # backbone.forward_features keeps the spatial map (B, C, H, W)
        image_features = self.image_model.backbone.forward_features(patches)
        patch_embeddings = self.projector(image_features)
        # Match the LLM's bfloat16 weights
        patch_embeddings = patch_embeddings.to(torch.bfloat16)

        embeddings = torch.cat([
            self.prompt_embeddings.expand(patches.size(0), -1, -1),
            patch_embeddings,
        ], dim=1)

        prompt_mask = torch.ones(patches.size(0), self.prompt_embeddings.size(1), device=device)
        patch_mask = torch.ones(patches.size(0), patch_embeddings.size(1), device=device)
        attention_mask = torch.cat([prompt_mask, patch_mask], dim=1)

        return self.language_model.generate(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            **generator_kwargs
        )

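# Layout note (sketch): the LLM sees [prompt embeddings | 32 image tokens] via
# inputs_embeds. When generate() is called with inputs_embeds and no input_ids,
# recent transformers versions return only the newly generated token ids, so
# the prompt text is normally not echoed in the decoded output.
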
vlm_model = None
tokenizer = None

def download_and_load_models():
    global vlm_model, tokenizer

    print("Starting model download and initialization...")

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        print("CUDA available - using GPU")
    else:
        device = torch.device("cpu")
        print("CUDA not available - using CPU")

    repo_id = "aneeshm44/regfinal"
    print(f"Downloading from repo: {repo_id}")

    local_dir = tempfile.mkdtemp(prefix="regfinal_")
    print(f"Local directory: {local_dir}")

    try:
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=local_dir,
            allow_patterns=[
                "llmweights/*",
                "imagemodelweights/finalcheckpoint.pth",
                "projectorweights/projector.pth"
            ],
            local_dir_use_symlinks=False,  # deprecated and ignored in newer huggingface_hub releases
        )
        print("Download completed successfully")
    except Exception as e:
        print(f"Download failed: {e}")
        raise

    llm_path = os.path.join(local_dir, "llmweights")
    image_weights_path = os.path.join(local_dir, "imagemodelweights", "finalcheckpoint.pth")
    projector_weights_path = os.path.join(local_dir, "projectorweights", "projector.pth")

    print("Loading language model...")
    try:
        language_model = AutoModelForCausalLM.from_pretrained(
            llm_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        )
        language_model.eval()
        language_model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(llm_path)
        print("Language model loaded successfully")
    except Exception as e:
        print(f"Language model loading failed: {e}")
        raise

    print("Loading vision model...")
    try:
        image_model = TimmCNNModel(num_classes=8)
        weights = torch.load(image_weights_path, map_location=device)
        image_model.load_state_dict(weights['model_state_dict'])

        for param in image_model.parameters():
            param.requires_grad = False
        image_model.eval()
        image_model.to(device)
        print("Vision model loaded successfully")
    except Exception as e:
        print(f"Vision model loading failed: {e}")
        raise

    print("Loading projector...")
    try:
        projector = Projector_4to3d(cnn_dim=1280, llm_dim=2048, num_heads=8)
        weights = torch.load(projector_weights_path, map_location=device)
        projector.load_state_dict(weights)

        for param in projector.parameters():
            param.requires_grad = False
        projector.eval()
        projector.to(device)
        print("Projector loaded successfully")
    except Exception as e:
        print(f"Projector loading failed: {e}")
        raise

    print("Creating VLM model...")
    try:
        vlm_model = Model(image_model, language_model, projector, tokenizer, prompt="Describe this image:")
        vlm_model = vlm_model.to(device)
        print("VLM model created successfully")
    except Exception as e:
        print(f"VLM model creation failed: {e}")
        raise

    print("All models loaded successfully!")

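# Expected dataset repo layout (read off the allow_patterns above):
#   llmweights/                            HF-format causal LM + tokenizer files
#   imagemodelweights/finalcheckpoint.pth  checkpoint dict with 'model_state_dict'
#   projectorweights/projector.pth         plain projector state_dict
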
def tensor_to_pil_image(tensor):
    """Convert a normalized tensor back to a PIL image for display"""
    # Undo the ImageNet normalization applied during preprocessing
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Remove batch dimension and denormalize
    img_tensor = tensor.squeeze(0)
    img_tensor = img_tensor * std + mean
    img_tensor = torch.clamp(img_tensor, 0, 1)

    # Convert to PIL
    img_array = img_tensor.permute(1, 2, 0).numpy()
    img_array = (img_array * 255).astype(np.uint8)
    return Image.fromarray(img_array)

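# Round-trip note (sketch): this inverts the normalization applied in
# describe_image() below, so the displayed "processed" image approximates
# the resized RGB input.
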
def describe_image(image, temperature, top_p, max_tokens):
    """Generate description for uploaded image"""
    global vlm_model, tokenizer

    if vlm_model is None:
        return "Models not loaded yet. Please wait for initialization to complete.", None

    if image is None:
        return "Please upload an image.", None

    try:
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif hasattr(image, 'convert'):
            image = image.convert('RGB')

        # Preprocess: resize, apply ImageNet normalization, and add a batch
        # dimension. The 256x256 size is an assumption inferred from the
        # projector's 8x8 (64-token) grid and efficientnet_b0's stride of 32.
        image = image.resize((256, 256))
        img_array = np.asarray(image, dtype=np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        img_array = (img_array - mean) / std
        image_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()

        processed_image = tensor_to_pil_image(image_tensor)

        # Generation parameters
        generator_kwargs = {
            "max_new_tokens": int(max_tokens),
            "do_sample": True,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "pad_token_id": tokenizer.eos_token_id
        }

        with torch.no_grad():
            output_ids = vlm_model.generate(image_tensor, generator_kwargs)
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Defensive fallback: with inputs_embeds-only generation the prompt is
        # normally not echoed, but strip it if it is
        if "Describe this image:" in text:
            description = text.split("Describe this image:")[-1].strip()
        else:
            description = text.strip()

        result_text = description if description else "Unable to generate description."

        return result_text, processed_image

    except Exception as e:
        return f"Error processing image: {str(e)}", None

def reset_interface():
    return None, "Models loaded successfully! Upload an image to get started.", None

try:
    download_and_load_models()
    initial_status = "Models loaded successfully! Upload an image to get started."
except Exception as e:
    initial_status = f"Failed to load models: {str(e)}"

# Gradio Interface
def create_interface():
    with gr.Blocks(title="WSI Pathology Report using Gemma3n") as demo:
        gr.Markdown("# WSI Pathology Report using Gemma3n")
        gr.Markdown("Upload a pathology image and get an AI-generated pathology report.")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload WSI Image")

                gr.Markdown("Generation parameters")
                with gr.Row():
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.4,
                        step=0.1,
                        label="Temperature",
                        info="Lower values give more consistent results; higher values give more creative output"
                    )

                    top_p_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        label="Top-p",
                        info="Lower values = more focused vocabulary, higher values = more diverse vocabulary"
                    )

                    max_tokens_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=60,
                        step=10,
                        label="Max Tokens"
                    )

                with gr.Row():
                    submit_btn = gr.Button("Generate Report", variant="primary")
                    reset_btn = gr.Button("Reset", variant="secondary")

            with gr.Column():
                output_text = gr.Textbox(
                    label="Pathology Report",
                    lines=8,
                    value=initial_status,
                    show_copy_button=True
                )

                processed_image = gr.Image(
                    label="Processed WSI Image",
                    show_download_button=True
                )

        # Event handlers
        submit_btn.click(
            fn=describe_image,
            inputs=[image_input, temperature_slider, top_p_slider, max_tokens_slider],
            outputs=[output_text, processed_image]
        )

        # Auto-generate on image upload (note: clearing the image also fires
        # this handler, which overwrites the reset message)
        image_input.change(
            fn=describe_image,
            inputs=[image_input, temperature_slider, top_p_slider, max_tokens_slider],
            outputs=[output_text, processed_image]
        )

        # Reset functionality
        reset_btn.click(
            fn=reset_interface,
            inputs=[],
            outputs=[image_input, output_text, processed_image]
        )

    return demo

# Launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
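
# Running locally (a sketch; dependency list inferred from the imports above):
#   pip install torch timm einops transformers gradio huggingface_hub pillow numpy
#   python app.py   # serves the Gradio app on http://0.0.0.0:7860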