# Gemma3nSolution / app.py
import os
import tempfile

import timm
import einops
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
from typing import Union
import torchvision.transforms as transforms
# Vision Model
class TimmCNNModel(nn.Module):
    def __init__(self, num_classes: int = 8, model_name: str = "efficientnet_b0"):
        super().__init__()
        # Use the configurable model_name rather than a hard-coded backbone.
        self.backbone = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,  # drop timm's classification head; returns pooled features
        )
self.feature_dim = self.backbone.num_features
self.classifier = nn.Sequential(
nn.Dropout(0.1),
nn.Linear(self.feature_dim, 512),
nn.ReLU(inplace=True),
nn.BatchNorm1d(512),
nn.Dropout(0.1),
nn.Linear(512, 256),
nn.ReLU(inplace=True),
nn.Linear(256, num_classes)
)
    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # With num_classes=0, the backbone returns pooled features of shape (B, feature_dim).
        return self.backbone(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.forward_features(x)
logits = self.classifier(features)
return logits
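# Shape sketch (hypothetical, kept as comments so nothing runs at import time):
# with a 256x256 RGB input, EfficientNet-B0's unpooled feature map is 8x8, which
# is why the projector below uses 64 spatial positions.
#
#   x = torch.randn(1, 3, 256, 256)
#   fmap = TimmCNNModel().backbone.forward_features(x)  # -> (1, 1280, 8, 8)
#   logits = TimmCNNModel()(x)                          # -> (1, 8)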
# Projector Model
class Projector_4to3d(nn.Module):
def __init__(self, cnn_dim: int = 1280, llm_dim: int = 2048, num_heads: int = 8, dropout: float = 0.1):
super().__init__()
self.cnn_dim = cnn_dim
self.llm_dim = llm_dim
# Spatial positional embeddings for 8x8 grid
self.spatial_pos_embed = nn.Parameter(torch.randn(64, cnn_dim))
# Multi-scale feature processing
self.spatial_conv = nn.Conv2d(cnn_dim, cnn_dim // 2, 1)
self.global_pool = nn.AdaptiveAvgPool2d(1)
# Enhanced projection layers
self.input_proj = nn.Sequential(
nn.Linear(cnn_dim, llm_dim),
nn.LayerNorm(llm_dim),
nn.ReLU(),
nn.Dropout(dropout)
)
# Multi-head self-attention for spatial reasoning
self.spatial_attention = nn.MultiheadAttention(
embed_dim=llm_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
# Cross-attention for text-image alignment
self.cross_attention = nn.MultiheadAttention(
embed_dim=llm_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
self.norm1 = nn.LayerNorm(llm_dim)
self.norm2 = nn.LayerNorm(llm_dim)
# Enhanced FFN
self.ffn = nn.Sequential(
nn.Linear(llm_dim, llm_dim * 4),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(llm_dim * 4, llm_dim),
nn.Dropout(dropout)
)
self.norm3 = nn.LayerNorm(llm_dim)
# Token compression layer
self.compress_tokens = nn.Parameter(torch.randn(32, llm_dim))
self.token_compression = nn.MultiheadAttention(
embed_dim=llm_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
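        # The 32 learned queries cross-attend over the spatial tokens (Q-Former
        # style), halving the visual sequence length handed to the LLM.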
self._init_weights()
def _init_weights(self):
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight)
    def forward(self, cnn_features: torch.Tensor, text_embeddings: torch.Tensor = None) -> torch.Tensor:
        batch_size = cnn_features.shape[0]
        # Multi-scale branches (note: computed but currently unused downstream;
        # the modules stay defined so the trained checkpoint still loads)
        spatial_features = self.spatial_conv(cnn_features)
        global_context = self.global_pool(cnn_features).flatten(1)
# Flatten spatial features and add positional encoding
x = einops.rearrange(cnn_features, "b c h w -> b (h w) c")
pos_embeddings = self.spatial_pos_embed.unsqueeze(0).expand(batch_size, -1, -1)
x = x + pos_embeddings
# Project to LLM dimension
x = self.input_proj(x)
# Self-attention for spatial reasoning
attended_x, spatial_attn_weights = self.spatial_attention(x, x, x)
x = self.norm1(x + attended_x)
# Cross-attention with text (if available)
if text_embeddings is not None:
text_embeddings_float = text_embeddings.float()
cross_attended, cross_attn_weights = self.cross_attention(x, text_embeddings_float, text_embeddings_float)
x = self.norm2(x + cross_attended)
# FFN
ffn_out = self.ffn(x)
x = self.norm3(x + ffn_out)
# Optional token compression
compress_queries = self.compress_tokens.unsqueeze(0).expand(batch_size, -1, -1)
compressed_x, _ = self.token_compression(compress_queries, x, x)
return compressed_x
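# Projector shape walk-through (hypothetical values, assuming the defaults above):
#   cnn_features (B, 1280, 8, 8)
#   -> 64 spatial tokens of dim 1280, plus learned position embeddings
#   -> projected to (B, 64, 2048), refined by self-/cross-attention and the FFN
#   -> compressed by 32 learned queries to (B, 32, 2048) for the language model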
# Main VLM Model
class Model(nn.Module):
def __init__(self, image_model, language_model, projector, tokenizer, prompt="Describe this image:"):
super().__init__()
self.image_model = image_model
self.language_model = language_model
self.projector = projector
self.tokenizer = tokenizer
self.eos_token = tokenizer.eos_token
self.prompt = prompt
device = next(self.language_model.parameters()).device
self.image_model.to(device)
self.projector.to(device)
# Create prompt embeddings
prompt_tokens = tokenizer(text=prompt, return_tensors="pt").input_ids.to(device)
prompt_embeddings = language_model.get_input_embeddings()(prompt_tokens).detach()
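        # A buffer (not a Parameter) keeps the frozen prompt embeddings on the
        # model's device without registering them as trainable weights.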
self.register_buffer('prompt_embeddings', prompt_embeddings)
@property
def device(self):
return next(self.parameters()).device
    def generate(self, patches: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        device = self.device
        patches = patches.to(device)
        # timm's backbone.forward_features keeps the spatial map (B, 1280, 8, 8),
        # unlike TimmCNNModel.forward_features, which returns pooled features.
        image_features = self.image_model.backbone.forward_features(patches)
        patch_embeddings = self.projector(image_features)
        # Match the language model's embedding dtype (bfloat16 as loaded below).
        patch_embeddings = patch_embeddings.to(self.prompt_embeddings.dtype)
embeddings = torch.cat([
self.prompt_embeddings.expand(patches.size(0), -1, -1),
patch_embeddings,
], dim=1)
prompt_mask = torch.ones(patches.size(0), self.prompt_embeddings.size(1), device=device)
patch_mask = torch.ones(patches.size(0), patch_embeddings.size(1), device=device)
attention_mask = torch.cat([prompt_mask, patch_mask], dim=1)
return self.language_model.generate(
inputs_embeds=embeddings,
attention_mask=attention_mask,
**generator_kwargs
)
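# Usage sketch (hypothetical, kept as comments so nothing runs at import time):
#   patches = transform(pil_image).unsqueeze(0)  # (1, 3, 256, 256)
#   ids = vlm_model.generate(patches, {"max_new_tokens": 64, "do_sample": False})
#   text = tokenizer.decode(ids[0], skip_special_tokens=True)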
vlm_model = None
tokenizer = None
transform = None
def download_and_load_models():
global vlm_model, tokenizer, transform
print("Starting model download and initialization...")
if torch.cuda.is_available():
device = torch.device("cuda:0")
print("CUDA available - using GPU")
else:
device = torch.device("cpu")
print("CUDA not available - using CPU")
repo_id = "aneeshm44/regfinal"
print(f"Downloading from repo: {repo_id}")
local_dir = tempfile.mkdtemp(prefix="regfinal_")
print(f"Local directory: {local_dir}")
try:
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
local_dir=local_dir,
allow_patterns=[
"llmweights/*",
"imagemodelweights/finalcheckpoint.pth",
"projectorweights/projector.pth"
],
            local_dir_use_symlinks=False,  # deprecated (ignored) in recent huggingface_hub; harmless on older versions
)
print("Download completed successfully")
except Exception as e:
print(f"Download failed: {e}")
raise e
llm_path = os.path.join(local_dir, "llmweights")
image_weights_path = os.path.join(local_dir, "imagemodelweights", "finalcheckpoint.pth")
projector_weights_path = os.path.join(local_dir, "projectorweights", "projector.pth")
print("Loading language model...")
try:
language_model = AutoModelForCausalLM.from_pretrained(
llm_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
language_model.eval()
language_model.to(device)
tokenizer = AutoTokenizer.from_pretrained(llm_path)
print("Language model loaded successfully")
except Exception as e:
print(f"Language model loading failed: {e}")
raise e
print("Loading vision model...")
try:
image_model = TimmCNNModel(num_classes=8)
weights = torch.load(image_weights_path, map_location=device)
image_model.load_state_dict(weights['model_state_dict'])
for param in image_model.parameters():
param.requires_grad = False
image_model.eval()
image_model.to(device)
print("Vision model loaded successfully")
except Exception as e:
print(f"Vision model loading failed: {e}")
raise e
print("Loading projector...")
try:
projector = Projector_4to3d(cnn_dim=1280, llm_dim=2048, num_heads=8)
weights = torch.load(projector_weights_path, map_location=device)
projector.load_state_dict(weights)
for param in projector.parameters():
param.requires_grad = False
projector.eval()
projector.to(device)
print("Projector loaded successfully")
except Exception as e:
print(f"Projector loading failed: {e}")
raise e
print("Creating VLM model...")
try:
vlm_model = Model(image_model, language_model, projector, tokenizer, prompt="Describe this image:")
vlm_model = vlm_model.to(device)
print("VLM model created successfully")
except Exception as e:
print(f"VLM model creation failed: {e}")
raise e
    # Preprocessing matches tensor_to_pil_image's assumptions: resize to 256x256
    # and scale to [0, 1] via ToTensor, with no mean/std normalization.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
print("All models loaded successfully!")
def tensor_to_pil_image(tensor):
    """Convert a (1, C, H, W) float tensor with values in [0, 1] to a PIL image."""
    img_tensor = tensor.squeeze(0)
img_tensor = torch.clamp(img_tensor, 0, 1)
img_array = img_tensor.permute(1, 2, 0).numpy()
img_array = (img_array * 255).astype(np.uint8)
return Image.fromarray(img_array)
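# Round-trip sketch (hypothetical): transform(img) yields a (3, 256, 256) tensor
# in [0, 1]; unsqueeze(0) adds the batch dimension that both tensor_to_pil_image
# and Model.generate expect.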
def on_image_upload(image):
    if image is not None:
        return "Image received. Click 'Generate Report' to produce a report."
    else:
        return "Models are loaded. Upload a WSI image to get started."
def describe_image(image, temperature, top_p, max_tokens, progress=gr.Progress()):
global vlm_model, tokenizer, transform
if vlm_model is None:
return "Models not loaded yet. Please wait for initialization to complete.", None
if image is None:
return "Please upload an image.", None
try:
progress(0.1, desc="Starting image processing...")
# Preprocess image
if isinstance(image, str):
image = Image.open(image).convert('RGB')
elif hasattr(image, 'convert'):
image = image.convert('RGB')
progress(0.3, desc="Applying image transformations...")
image_tensor = transform(image).unsqueeze(0) # Add batch dimension
processed_image = tensor_to_pil_image(image_tensor)
progress(0.5, desc="Setting up generation parameters...")
# Generation parameters
generator_kwargs = {
"max_new_tokens": int(max_tokens),
"do_sample": True,
"temperature": float(temperature),
"top_p": float(top_p),
"pad_token_id": tokenizer.eos_token_id
}
progress(0.7, desc="Generating pathology report...")
# Generate description
with torch.no_grad():
output_ids = vlm_model.generate(image_tensor, generator_kwargs)
text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
progress(0.9, desc="Finalizing report...")
if "Describe this image:" in text:
description = text.split("Describe this image:")[-1].strip()
else:
description = text.strip()
result_text = description if description else "Unable to generate description."
progress(1.0, desc="Complete!")
return result_text, processed_image
except Exception as e:
return f"Error processing image: {str(e)}", None
def reset_interface():
    return None, "Models are loaded. Upload a WSI image to get started.", None
try:
download_and_load_models()
    initial_status = "Models are loaded. Upload a WSI image to get started."
except Exception as e:
initial_status = f"Failed to load models: {str(e)}"
def create_interface():
with gr.Blocks(title="WSI Pathology Report using Gemma3n") as demo:
gr.Markdown("# WSI Pathology Report using Gemma3n")
gr.Markdown("Upload a pathology WSI to get concise a report")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Upload WSI file")
# Generation parameters
with gr.Row():
temperature_slider = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.6,
step=0.1,
label="Temperature",
info="Lower values give consistent results and Higher values produce creative results"
)
top_p_slider = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.1,
label="Top-p",
info="Lower values use a more focused vocabulary for sampling compared to a more diverse vocabulary in Higher values"
)
max_tokens_slider = gr.Slider(
minimum=10,
maximum=200,
value=100,
step=10,
label="Max Tokens for generation"
)
with gr.Row():
submit_btn = gr.Button("Generate Report", variant="primary")
reset_btn = gr.Button("Reset", variant="secondary")
with gr.Column():
output_text = gr.Textbox(
label="Pathology Report",
lines=8,
value=initial_status,
show_copy_button=True
)
processed_image = gr.Image(
label="Processed WSI",
show_download_button=True
)
image_input.change(
fn=on_image_upload,
inputs=[image_input],
outputs=[output_text]
)
submit_btn.click(
fn=describe_image,
inputs=[image_input, temperature_slider, top_p_slider, max_tokens_slider],
outputs=[output_text, processed_image],
show_progress=True
)
reset_btn.click(
fn=reset_interface,
inputs=[],
outputs=[image_input, output_text, processed_image]
)
return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True
)