aneeshm44 commited on
Commit
089ec60
·
verified ·
1 Parent(s): 0dc73ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -66
app.py CHANGED
@@ -15,11 +15,8 @@ from PIL import Image
15
  import gradio as gr
16
  from huggingface_hub import snapshot_download
17
  from typing import List, Union, Dict
 
18
 
19
- # Configuration
20
- class CFG:
21
- MAX_LENGTH = 512
22
- LABEL_MASK = -100
23
 
24
  # Vision Model
25
  class TimmCNNModel(nn.Module):
@@ -44,7 +41,7 @@ class TimmCNNModel(nn.Module):
44
  nn.ReLU(inplace=True),
45
  nn.Linear(256, num_classes)
46
  )
47
-
48
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
49
  return self.backbone(x)
50
 
@@ -211,17 +208,15 @@ class Model(nn.Module):
211
  **generator_kwargs
212
  )
213
 
214
- # Global variables for models
215
  vlm_model = None
216
  tokenizer = None
 
217
 
218
  def download_and_load_models():
219
- """Download models and load them into memory"""
220
- global vlm_model, tokenizer
221
 
222
  print("Starting model download and initialization...")
223
 
224
- # Set device
225
  if torch.cuda.is_available():
226
  device = torch.device("cuda:0")
227
  print("CUDA available - using GPU")
@@ -229,7 +224,6 @@ def download_and_load_models():
229
  device = torch.device("cpu")
230
  print("CUDA not available - using CPU")
231
 
232
- # Download weights
233
  repo_id = "aneeshm44/regfinal"
234
  print(f"Downloading from repo: {repo_id}")
235
 
@@ -253,12 +247,10 @@ def download_and_load_models():
253
  print(f"Download failed: {e}")
254
  raise e
255
 
256
- # Set paths
257
  llm_path = os.path.join(local_dir, "llmweights")
258
  image_weights_path = os.path.join(local_dir, "imagemodelweights", "finalcheckpoint.pth")
259
  projector_weights_path = os.path.join(local_dir, "projectorweights", "projector.pth")
260
 
261
- # Load Language Model
262
  print("Loading language model...")
263
  try:
264
  language_model = AutoModelForCausalLM.from_pretrained(
@@ -276,7 +268,6 @@ def download_and_load_models():
276
  print(f"Language model loading failed: {e}")
277
  raise e
278
 
279
- # Load Vision Model
280
  print("Loading vision model...")
281
  try:
282
  image_model = TimmCNNModel(num_classes=8)
@@ -292,7 +283,6 @@ def download_and_load_models():
292
  print(f"Vision model loading failed: {e}")
293
  raise e
294
 
295
- # Load Projector
296
  print("Loading projector...")
297
  try:
298
  projector = Projector_4to3d(cnn_dim=1280, llm_dim=2048, num_heads=8)
@@ -308,7 +298,6 @@ def download_and_load_models():
308
  print(f"Projector loading failed: {e}")
309
  raise e
310
 
311
- # Create VLM Model
312
  print("Creating VLM model...")
313
  try:
314
  vlm_model = Model(image_model, language_model, projector, tokenizer, prompt="Describe this image:")
@@ -318,38 +307,29 @@ def download_and_load_models():
318
  print(f"VLM model creation failed: {e}")
319
  raise e
320
 
321
- print("All models loaded successfully!")
322
-
323
- def pil_to_tensor(image):
324
- """Convert PIL image directly to tensor without normalization"""
325
- # Convert PIL to numpy array
326
- img_array = np.array(image)
327
-
328
- # Convert to tensor and normalize to [0, 1] range
329
- img_tensor = torch.from_numpy(img_array).float() / 255.0
330
 
331
- # Rearrange from HWC to CHW format
332
- img_tensor = img_tensor.permute(2, 0, 1)
333
-
334
- # Add batch dimension
335
- img_tensor = img_tensor.unsqueeze(0)
336
-
337
- return img_tensor
338
 
339
  def tensor_to_pil_image(tensor):
340
- """Convert tensor to PIL image for display"""
341
- # Remove batch dimension and clamp values
342
  img_tensor = tensor.squeeze(0)
343
  img_tensor = torch.clamp(img_tensor, 0, 1)
344
 
345
- # Convert to PIL
346
  img_array = img_tensor.permute(1, 2, 0).numpy()
347
  img_array = (img_array * 255).astype(np.uint8)
348
  return Image.fromarray(img_array)
349
 
 
 
 
 
 
 
350
  def describe_image(image, temperature, top_p, max_tokens, progress=gr.Progress()):
351
- """Generate description for uploaded image"""
352
- global vlm_model, tokenizer
353
 
354
  if vlm_model is None:
355
  return "Models not loaded yet. Please wait for initialization to complete.", None
@@ -358,7 +338,6 @@ def describe_image(image, temperature, top_p, max_tokens, progress=gr.Progress()
358
  return "Please upload an image.", None
359
 
360
  try:
361
- # Progress tracking
362
  progress(0.1, desc="Starting image processing...")
363
 
364
  # Preprocess image
@@ -367,12 +346,10 @@ def describe_image(image, temperature, top_p, max_tokens, progress=gr.Progress()
367
  elif hasattr(image, 'convert'):
368
  image = image.convert('RGB')
369
 
370
- progress(0.3, desc="Converting image to tensor...")
371
 
372
- # Convert PIL image directly to tensor
373
- image_tensor = pil_to_tensor(image)
374
 
375
- # Convert tensor to PIL image for display
376
  processed_image = tensor_to_pil_image(image_tensor)
377
 
378
  progress(0.5, desc="Setting up generation parameters...")
@@ -395,7 +372,6 @@ def describe_image(image, temperature, top_p, max_tokens, progress=gr.Progress()
395
 
396
  progress(0.9, desc="Finalizing report...")
397
 
398
- # Clean up the output (remove the prompt)
399
  if "Describe this image:" in text:
400
  description = text.split("Describe this image:")[-1].strip()
401
  else:
@@ -411,35 +387,32 @@ def describe_image(image, temperature, top_p, max_tokens, progress=gr.Progress()
411
  return f"Error processing image: {str(e)}", None
412
 
413
  def reset_interface():
414
- """Reset the interface by clearing all outputs"""
415
- return None, "Models loaded successfully! Upload an image to get started.", None
416
 
417
- # Initialize models when the script starts
418
  try:
419
  download_and_load_models()
420
- initial_status = "Models loaded successfully! Upload an image to get started."
421
  except Exception as e:
422
  initial_status = f"Failed to load models: {str(e)}"
423
 
424
- # Create Gradio Interface
425
  def create_interface():
426
  with gr.Blocks(title="WSI Pathology Report using Gemma3n") as demo:
427
  gr.Markdown("# WSI Pathology Report using Gemma3n")
428
- gr.Markdown("Upload a pathology image and get an AI-generated pathology report.")
429
 
430
  with gr.Row():
431
  with gr.Column():
432
- image_input = gr.Image(type="pil", label="Upload WSI Image")
433
 
434
  # Generation parameters
435
  with gr.Row():
436
  temperature_slider = gr.Slider(
437
  minimum=0.1,
438
  maximum=1.0,
439
- value=0.4,
440
  step=0.1,
441
  label="Temperature",
442
- info="Lower values = more focused/consistent, Higher values = more creative/varied"
443
  )
444
 
445
  top_p_slider = gr.Slider(
@@ -448,7 +421,7 @@ def create_interface():
448
  value=0.9,
449
  step=0.1,
450
  label="Top-p",
451
- info="Lower values = more focused vocabulary, Higher values = more diverse vocabulary"
452
  )
453
 
454
  max_tokens_slider = gr.Slider(
@@ -456,7 +429,7 @@ def create_interface():
456
  maximum=200,
457
  value=100,
458
  step=10,
459
- label="Max Tokens"
460
  )
461
 
462
  with gr.Row():
@@ -472,27 +445,23 @@ def create_interface():
472
  )
473
 
474
  processed_image = gr.Image(
475
- label="Processed Image Tensor",
476
  show_download_button=True
477
  )
478
 
479
- # Event handlers
480
- submit_btn.click(
481
- fn=describe_image,
482
- inputs=[image_input, temperature_slider, top_p_slider, max_tokens_slider],
483
- outputs=[output_text, processed_image],
484
- show_progress=True
485
  )
486
 
487
- # Auto-generate on image upload
488
- image_input.change(
489
  fn=describe_image,
490
  inputs=[image_input, temperature_slider, top_p_slider, max_tokens_slider],
491
  outputs=[output_text, processed_image],
492
  show_progress=True
493
  )
494
 
495
- # Reset functionality
496
  reset_btn.click(
497
  fn=reset_interface,
498
  inputs=[],
@@ -501,7 +470,6 @@ def create_interface():
501
 
502
  return demo
503
 
504
- # Launch the interface
505
  if __name__ == "__main__":
506
  demo = create_interface()
507
  demo.launch(
@@ -509,5 +477,4 @@ if __name__ == "__main__":
509
  server_port=7860,
510
  share=False,
511
  show_error=True
512
- )
513
-
 
15
  import gradio as gr
16
  from huggingface_hub import snapshot_download
17
  from typing import List, Union, Dict
18
+ import torchvision.transforms as transforms
19
 
 
 
 
 
20
 
21
  # Vision Model
22
  class TimmCNNModel(nn.Module):
 
41
  nn.ReLU(inplace=True),
42
  nn.Linear(256, num_classes)
43
  )
44
+
45
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
46
  return self.backbone(x)
47
 
 
208
  **generator_kwargs
209
  )
210
 
 
211
  vlm_model = None
212
  tokenizer = None
213
+ transform = None
214
 
215
  def download_and_load_models():
216
+ global vlm_model, tokenizer, transform
 
217
 
218
  print("Starting model download and initialization...")
219
 
 
220
  if torch.cuda.is_available():
221
  device = torch.device("cuda:0")
222
  print("CUDA available - using GPU")
 
224
  device = torch.device("cpu")
225
  print("CUDA not available - using CPU")
226
 
 
227
  repo_id = "aneeshm44/regfinal"
228
  print(f"Downloading from repo: {repo_id}")
229
 
 
247
  print(f"Download failed: {e}")
248
  raise e
249
 
 
250
  llm_path = os.path.join(local_dir, "llmweights")
251
  image_weights_path = os.path.join(local_dir, "imagemodelweights", "finalcheckpoint.pth")
252
  projector_weights_path = os.path.join(local_dir, "projectorweights", "projector.pth")
253
 
 
254
  print("Loading language model...")
255
  try:
256
  language_model = AutoModelForCausalLM.from_pretrained(
 
268
  print(f"Language model loading failed: {e}")
269
  raise e
270
 
 
271
  print("Loading vision model...")
272
  try:
273
  image_model = TimmCNNModel(num_classes=8)
 
283
  print(f"Vision model loading failed: {e}")
284
  raise e
285
 
 
286
  print("Loading projector...")
287
  try:
288
  projector = Projector_4to3d(cnn_dim=1280, llm_dim=2048, num_heads=8)
 
298
  print(f"Projector loading failed: {e}")
299
  raise e
300
 
 
301
  print("Creating VLM model...")
302
  try:
303
  vlm_model = Model(image_model, language_model, projector, tokenizer, prompt="Describe this image:")
 
307
  print(f"VLM model creation failed: {e}")
308
  raise e
309
 
310
+ transform = transforms.Compose([
311
+ transforms.Resize((256, 256)),
312
+ transforms.ToTensor(),
313
+ ])
 
 
 
 
 
314
 
315
+ print("All models loaded successfully!")
 
 
 
 
 
 
316
 
317
  def tensor_to_pil_image(tensor):
 
 
318
  img_tensor = tensor.squeeze(0)
319
  img_tensor = torch.clamp(img_tensor, 0, 1)
320
 
 
321
  img_array = img_tensor.permute(1, 2, 0).numpy()
322
  img_array = (img_array * 255).astype(np.uint8)
323
  return Image.fromarray(img_array)
324
 
325
+ def on_image_upload(image):
326
+ if image is not None:
327
+ return "Image processed, click 'Generate Report' to produce report."
328
+ else:
329
+ return "Models are loaded, upload the Image to get started."
330
+
331
  def describe_image(image, temperature, top_p, max_tokens, progress=gr.Progress()):
332
+ global vlm_model, tokenizer, transform
 
333
 
334
  if vlm_model is None:
335
  return "Models not loaded yet. Please wait for initialization to complete.", None
 
338
  return "Please upload an image.", None
339
 
340
  try:
 
341
  progress(0.1, desc="Starting image processing...")
342
 
343
  # Preprocess image
 
346
  elif hasattr(image, 'convert'):
347
  image = image.convert('RGB')
348
 
349
+ progress(0.3, desc="Applying image transformations...")
350
 
351
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
 
352
 
 
353
  processed_image = tensor_to_pil_image(image_tensor)
354
 
355
  progress(0.5, desc="Setting up generation parameters...")
 
372
 
373
  progress(0.9, desc="Finalizing report...")
374
 
 
375
  if "Describe this image:" in text:
376
  description = text.split("Describe this image:")[-1].strip()
377
  else:
 
387
  return f"Error processing image: {str(e)}", None
388
 
389
  def reset_interface():
390
+ return None, "Models are loaded, upload the WSI file to get started.", None
 
391
 
 
392
  try:
393
  download_and_load_models()
394
+ initial_status = "Models are loaded, upload the WSI file to get started."
395
  except Exception as e:
396
  initial_status = f"Failed to load models: {str(e)}"
397
 
 
398
  def create_interface():
399
  with gr.Blocks(title="WSI Pathology Report using Gemma3n") as demo:
400
  gr.Markdown("# WSI Pathology Report using Gemma3n")
401
+ gr.Markdown("Upload a pathology WSI to get concise a report")
402
 
403
  with gr.Row():
404
  with gr.Column():
405
+ image_input = gr.Image(type="pil", label="Upload WSI file")
406
 
407
  # Generation parameters
408
  with gr.Row():
409
  temperature_slider = gr.Slider(
410
  minimum=0.1,
411
  maximum=1.0,
412
+ value=0.6,
413
  step=0.1,
414
  label="Temperature",
415
+ info="Lower values give consistent results and Higher values produce creative results"
416
  )
417
 
418
  top_p_slider = gr.Slider(
 
421
  value=0.9,
422
  step=0.1,
423
  label="Top-p",
424
+ info="Lower values use a more focused vocabulary for sampling compared to a more diverse vocabulary in Higher values"
425
  )
426
 
427
  max_tokens_slider = gr.Slider(
 
429
  maximum=200,
430
  value=100,
431
  step=10,
432
+ label="Max Tokens for generation"
433
  )
434
 
435
  with gr.Row():
 
445
  )
446
 
447
  processed_image = gr.Image(
448
+ label="Processed WSI",
449
  show_download_button=True
450
  )
451
 
452
+ image_input.change(
453
+ fn=on_image_upload,
454
+ inputs=[image_input],
455
+ outputs=[output_text]
 
 
456
  )
457
 
458
+ submit_btn.click(
 
459
  fn=describe_image,
460
  inputs=[image_input, temperature_slider, top_p_slider, max_tokens_slider],
461
  outputs=[output_text, processed_image],
462
  show_progress=True
463
  )
464
 
 
465
  reset_btn.click(
466
  fn=reset_interface,
467
  inputs=[],
 
470
 
471
  return demo
472
 
 
473
  if __name__ == "__main__":
474
  demo = create_interface()
475
  demo.launch(
 
477
  server_port=7860,
478
  share=False,
479
  show_error=True
480
+ )