import os
os.environ["TORCH_COMPILE_DISABLE"] = "1"
import time
import torch
import gradio as gr
from transformers import pipeline
from PIL import Image
from huggingface_hub import login
import spaces
# Disable torch.compile (Dynamo) to work around a known incompatibility
# between Gemma 3 and the compiled code path.
import torch._dynamo
torch._dynamo.config.disable = True  # belt-and-braces next to TORCH_COMPILE_DISABLE above
# Authenticate with the HF_TOKEN secret stored in the Space's settings
# so the gated MedGemma model can be downloaded.
login(token=os.getenv("HF_TOKEN"))
# Load the model (done only once, at startup)
pipe = pipeline(
    "image-text-to-text",
    model="google/medgemma-4b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    do_sample=False,  # greedy decoding: top_k/top_p below are inert unless sampling is enabled
    top_k=64,
    top_p=0.95,
)
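
# A minimal sketch of how the pipeline can be called directly (the message
# below is a hypothetical example of the chat format that the helper
# functions further down build):
#   pipe([{"role": "user", "content": [{"type": "text", "text": "Hi"}]}],
#        max_new_tokens=50)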
def create_user_message(user_input: dict) -> dict:
"""
Converts user input into a message dictionary
* This message dictionary is the correct format
that MedGemma accepts.
Args:
-----
- user_input: A dictionary of text and images.
Text key: "text", Images key: "files".
Images are represented by their paths in a list.
{"text": "What do you think?", "files": [...]}
"""
user_text = user_input["text"]
user_images = user_input["files"]
user_content = [] # referring to "content" field
if user_text:
user_content.append({"type": "text", "text": user_text})
    # Add the images attached to the current message
    if user_images:
        for img_path in user_images:
            # Convert to RGB for consistency with process_history
            image = Image.open(img_path).convert("RGB")
            user_content.append({"type": "image", "image": image})
# Create the full user message object
user_message = {"role": "user", "content": user_content}
return user_message
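
# A hypothetical example of the conversion above (the path is illustrative):
#   create_user_message({"text": "Describe this scan.", "files": ["scan.png"]})
#   -> {"role": "user", "content": [
#          {"type": "text", "text": "Describe this scan."},
#          {"type": "image", "image": <PIL.Image>}]}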
def process_history(history: list[dict]) -> list[dict]:
    """Convert Gradio-format history messages into MedGemma-format messages."""
    messages = []
    user_content = []
for message in history:
if message["role"] == "user":
content = message["content"]
if isinstance(content, str):
user_content.append({"type": "text", "text": content})
else:
image = Image.open(content[0]).convert("RGB")
user_content.append({"type": "image", "image": image})
        else:
            # Before adding the assistant message, flush any accumulated
            # user contents into a single user message.
            if user_content:
                messages.append({"role": "user", "content": user_content})
                user_content = []
            assis_content = [{"type": "text", "text": message["content"]}]
            messages.append({"role": "assistant", "content": assis_content})
    # Flush a trailing user message that has no assistant reply yet
    # (e.g. when a previous generation failed midway).
    if user_content:
        messages.append({"role": "user", "content": user_content})
    return messages
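
# Sketch of the conversion, with hypothetical contents (see the docstring
# of chat() below for a fuller side-by-side example):
#   [{"role": "user", "content": ("xray.png",)},
#    {"role": "user", "content": "Any fracture?"}]
#   -> [{"role": "user", "content": [{"type": "image", "image": <PIL.Image>},
#                                    {"type": "text", "text": "Any fracture?"}]}]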
def parse_output(output: list) -> str:
    output_dict = output[0]
    if "generated_text" not in output_dict:
        raise ValueError("Invalid model output")
    # generated_text = previous messages + response to the current query
    new_history = output_dict["generated_text"]
    if not isinstance(new_history, list):
        raise TypeError(f"History is not a list, it is {type(new_history)}")
    if len(new_history) < 3:
        raise ValueError("History should include at least 3 messages, which "
                         "are the system prompt, user query and given response")
# generated assistant message
# {"role": "assistant", "content": "..."}
# content is not a list, it is a string
assistant_message = new_history[-1]
content = assistant_message["content"]
return content
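
# For reference, the pipeline output parsed above has roughly this shape:
#   [{"generated_text": [
#        {"role": "system", "content": [...]},
#        {"role": "user", "content": [...]},
#        {"role": "assistant", "content": "the generated answer"}]}]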
# Define the inference function with state management
@spaces.GPU(duration=120, timeout=300)
def chat(
user_input: dict,
history: list[dict],
sys_prompt: str,
max_tokens: int
):
"""
Args:
-----
- user_input: the user's most recent message (dict for
multimodal case, str for non-multimodal)
- history: a list of gradio messages. Each message
refers to a dictionary of "role" and "content".
- sys_prompt: It sets the initial persona for the model.
- max_tokens: The maximum number of new tokens for the model
to generate.
    Gradio Messages vs MedGemma Messages:
    -------------------------------------
    * They are almost the same, but content is represented
      in different ways.
    * In a Gradio message, content is either a one-element tuple
      holding an image path or a plain string of text. MedGemma
      expects content as a list of dictionaries:
      - a list, so multiple pieces of content can travel together;
      - dicts, so each piece carries its type along with the
        content itself.
    Gradio Messages: [
        {"role": "user", "content": ("cat1.png",)},
        {"role": "user", "content": ("cat2.png",)},
        {"role": "user", "content": "How do they resemble each other?"},
    ]
    MedGemma Messages: [
        {"role": "user", "content": [
            {"type": "image", "image": PIL.Image},
            {"type": "image", "image": PIL.Image},
            {"type": "text", "text": "How do they resemble each other?"},
        ]},
    ]
"""
llm_messages = []
if sys_prompt:
sys_content = [{"type": "text", "text": sys_prompt}]
llm_messages.append({"role": "system", "content": sys_content})
    # history contains all the messages up to now;
    # convert them from Gradio format to LLM (MedGemma) format
    llm_messages.extend(process_history(history))
    # Convert the current user input into a MedGemma user message
    user_message = create_user_message(user_input)
    llm_messages.append(user_message)
# Generate a response from the model
# Note: We pass the complete message history to the pipe
output = pipe(llm_messages, max_new_tokens=max_tokens)
# Extract content of assistant message from model output
content = parse_output(output)
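    # Stream the answer by yielding progressively longer prefixes;
    # gr.ChatInterface re-renders the partial message on each yield.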
for i in range(len(content)):
time.sleep(0.01)
yield content[: i + 1]
description = "MedGemma is a variant of the Gemma 3 architecture, \
specifically optimized for medical text and image \
comprehension tasks. You can upload your images and ask \
text-based questions."
user_img = "https://huggingface.co/avatars/2c4ba9b7cf3a77322929737c35252857.svg"
gemma_img = "https://huggingface.co/spaces/goktug14/MedGemma/resolve/main/images/gemma.jpg"
demo = gr.ChatInterface(
fn=chat,
type="messages",
multimodal=True,
title="MedGemma Medical Assistant",
description=description,
chatbot=gr.Chatbot(type="messages", height=300, scale=1, avatar_images=(user_img, gemma_img)),
textbox=gr.MultimodalTextbox(file_types=["image"], file_count="multiple"),
css_paths=["style.css"],
    additional_inputs=[
        gr.Textbox("You are a helpful medical AI assistant.", label="System Prompt"),
        gr.Slider(minimum=100, maximum=2048, value=300, step=1, label="Max New Tokens"),
    ])
demo.launch()