# Brain / app.py — by sohiebwedyan
# Last commit: 714a740 "Update app.py" (verified)
import functools
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
# -----------------------------------------------------------------------------
# Load HF token from environment
# -----------------------------------------------------------------------------
# Prefer the variable this app has always used; fall back to HF_TOKEN, the
# name huggingface_hub reads natively, so either secret name works.
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # Fail fast at import time: the model below is loaded with this token.
    raise ValueError(
        "HUGGINGFACEHUB_API_TOKEN (or HF_TOKEN) environment variable not set"
    )
# -----------------------------------------------------------------------------
# 1) GPU inference function
# -----------------------------------------------------------------------------
@functools.lru_cache(maxsize=1)
def _load_model(model_id: str):
    """Load ``(processor, model)`` for *model_id* onto the GPU, once per id.

    Cached so repeated requests reuse the same weights instead of re-reading
    and re-uploading the whole checkpoint on every call. ``maxsize=1`` keeps
    at most one model referenced by the cache at a time.
    """
    processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map=None,
        token=HF_TOKEN,
    ).to(torch.device("cuda")).eval()
    return processor, model


def run_inference_on_gpu(
    model_id: str,
    image: Image.Image,
    prompt: str = "caption",
    max_new_tokens: int = 100,
) -> str:
    """Generate text for *image* with a PaliGemma checkpoint on the GPU.

    Args:
        model_id: Hugging Face model id of a PaliGemma checkpoint.
        image: input PIL image (converted to RGB by the processor).
        prompt: text instruction appended after the ``<image>`` placeholder.
        max_new_tokens: upper bound on generated tokens; coerced to ``int``
            because UI sliders may deliver floats.

    Returns:
        The decoded continuation only (the prompt is not echoed back).

    Raises:
        RuntimeError: if CUDA is not available.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available—check your PyTorch installation!")
    device = torch.device("cuda")

    processor, model = _load_model(model_id)

    # build multimodal prompt
    multimodal_prompt = f"<image> {prompt}"

    # prepare inputs
    inputs = processor(
        text=multimodal_prompt,
        images=[image],
        padding="longest",
        return_tensors="pt",
        do_convert_rgb=True,
    )
    # Floating tensors (pixel values) follow the model's dtype; integer
    # tensors (token ids, masks) must keep their integer dtype.
    inputs = {
        k: v.to(device, dtype=model.dtype) if torch.is_floating_point(v) else v.to(device)
        for k, v in inputs.items()
    }
    input_len = inputs["input_ids"].shape[-1]

    # generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            num_beams=3,
            do_sample=False,
        )

    # generate() returns prompt + continuation; decode only the new tokens.
    generated = outputs[0][input_len:]
    return processor.decode(generated.cpu(), skip_special_tokens=True)
# -----------------------------------------------------------------------------
# 2) Gradio UI
# -----------------------------------------------------------------------------
MODEL_ID = "mychen76/paligemma-3b-mix-448-med_30k-ct-brain"


def caption_fn(image, prompt, max_tokens):
    """Gradio callback: caption *image* with the ``MODEL_ID`` checkpoint.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        prompt: free-text instruction; a blank prompt falls back to
            ``"caption"``.
        max_tokens: slider value for the generation budget (may arrive as a
            float, so it is cast to ``int``).

    Returns:
        The generated caption text.

    Raises:
        gr.Error: if no image was uploaded; Gradio shows this message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload a CT scan image first.")
    prompt = (prompt or "").strip() or "caption"
    return run_inference_on_gpu(
        model_id=MODEL_ID,
        image=image,
        prompt=prompt,
        max_new_tokens=int(max_tokens),
    )
def _build_demo() -> gr.Interface:
    """Assemble the single-page captioning UI wired to ``caption_fn``."""
    # Components are created in the same order the Interface lists them.
    image_input = gr.Image(type="pil", label="Upload CT Scan")
    prompt_input = gr.Textbox(
        value="What do you see in this CT scan?",
        label="Prompt",
    )
    token_budget = gr.Slider(
        minimum=10,
        maximum=300,
        step=10,
        value=100,
        label="Max New Tokens",
    )
    caption_output = gr.Textbox(label="Model Caption")

    blurb = (
        "Upload a brain CT scan (or any image), write a short prompt, "
        "and let the PaliGemma model describe what it sees."
    )
    return gr.Interface(
        fn=caption_fn,
        inputs=[image_input, prompt_input, token_budget],
        outputs=caption_output,
        title="PaliGemma CT-Scan Captioning",
        description=blurb,
        allow_flagging="never",
    )


demo = _build_demo()

if __name__ == "__main__":
    # No public share link; the "Use via API" page is hidden.
    demo.launch(share=False, show_api=False)