import os

# Disable torch.compile (Dynamo) before torch is imported; this works around
# a known incompatibility between Dynamo and new model architectures such as
# Gemma 3.
os.environ["TORCH_COMPILE_DISABLE"] = "1"

import time

import torch
import torch._dynamo
import gradio as gr
import spaces
from huggingface_hub import login
from PIL import Image
from transformers import pipeline

# Belt and braces: also flip the Dynamo config flag at runtime.
torch._dynamo.config.disable = True

# Authenticate with the HF_TOKEN secret stored in the Space's settings;
# the token grants access to the gated MedGemma checkpoint.
login(token=os.getenv("HF_TOKEN"))

# Load the model (done only once, at startup). Decoding settings are passed
# per call in chat() below; with greedy decoding, top_k/top_p would be
# ignored anyway.
pipe = pipeline(
    "image-text-to-text",
    model="google/medgemma-4b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
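
# A minimal sketch of how this pipeline is driven (shapes only, not executed
# here; the prompt is made up):
# messages = [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}]
# out = pipe(messages, max_new_tokens=32)
# `out` is a list with one dict whose "generated_text" holds the full chat,
# i.e. the input messages plus the new assistant turn appended at the end.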

def create_user_message(user_input: dict) -> dict:
    """
    Convert the Gradio multimodal input into a message dictionary
    in the format that MedGemma accepts.

    Args:
    -----
    - user_input: A dictionary of text and images.
      Text key: "text", images key: "files".
      Images are represented by their paths in a list, e.g.
      {"text": "What do you think?", "files": [...]}
    """
    user_text = user_input["text"]
    user_images = user_input["files"]
    user_content = []  # becomes the "content" field of the message
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    # Add the images attached to the current turn
    if user_images:
        for img_path in user_images:
            image = Image.open(img_path).convert("RGB")
            user_content.append({"type": "image", "image": image})
    # Assemble the full user message object
    user_message = {"role": "user", "content": user_content}
    return user_message
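
# Sketch of the expected round-trip (the file name is hypothetical):
# create_user_message({"text": "What do you think?", "files": ["scan.png"]})
# -> {"role": "user", "content": [
#        {"type": "text", "text": "What do you think?"},
#        {"type": "image", "image": <PIL.Image.Image>}]}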

def process_history(history: list[dict]) -> list[dict]:
    """
    Convert Gradio chat history into MedGemma-style messages.

    Gradio stores each uploaded image as its own user message whose
    content is a one-element tuple of a file path; consecutive user
    entries are merged into a single user message here.
    """
    messages = []
    user_content = []
    for message in history:
        if message["role"] == "user":
            content = message["content"]
            if isinstance(content, str):
                user_content.append({"type": "text", "text": content})
            else:
                image = Image.open(content[0]).convert("RGB")
                user_content.append({"type": "image", "image": image})
        else:
            # Before adding the assistant message, flush all accumulated
            # user contents as a single user message.
            if user_content:
                messages.append({"role": "user", "content": user_content})
                user_content = []
            assis_content = [{"type": "text", "text": message["content"]}]
            messages.append({"role": "assistant", "content": assis_content})
    # Flush any trailing user contents (a history that ends with a user turn)
    if user_content:
        messages.append({"role": "user", "content": user_content})
    return messages
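
# Conversion sketch (paths hypothetical):
# [{"role": "user", "content": ("cat.png",)},
#  {"role": "user", "content": "What is this?"},
#  {"role": "assistant", "content": "A cat."}]
# becomes
# [{"role": "user", "content": [{"type": "image", "image": <PIL.Image>},
#                               {"type": "text", "text": "What is this?"}]},
#  {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}]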

def parse_output(output: list) -> str:
    """Extract the assistant's reply text from the pipeline output."""
    output_dict = output[0]
    if "generated_text" not in output_dict:
        raise ValueError("Invalid model output")
    # generated_text = previous messages + response to the current query
    new_history = output_dict["generated_text"]
    if not isinstance(new_history, list):
        raise TypeError(f"History is not a list, it is {type(new_history)}")
    if len(new_history) < 3:
        raise ValueError("History should include at least 3 messages: "
                         "the system prompt, the user query, and the "
                         "generated response")
    # The generated assistant message looks like
    # {"role": "assistant", "content": "..."}; here content is a plain
    # string, not a list of typed parts.
    assistant_message = new_history[-1]
    content = assistant_message["content"]
    return content
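
# Usage sketch (output shape assumed from the pipeline behaviour above):
# out = [{"generated_text": [sys_msg, user_msg,
#                            {"role": "assistant", "content": "Looks normal."}]}]
# parse_output(out)  # -> "Looks normal."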

# Define the inference function with state management.
@spaces.GPU  # request a ZeroGPU slot per call (a no-op on classic GPU hardware)
def chat(
    user_input: dict,
    history: list[dict],
    sys_prompt: str,
    max_tokens: int,
):
    """
    Args:
    -----
    - user_input: the user's most recent message (a dict in the
      multimodal case, a str otherwise)
    - history: a list of Gradio messages; each message is a
      dictionary with "role" and "content" keys.
    - sys_prompt: sets the initial persona for the model.
    - max_tokens: the maximum number of new tokens for the model
      to generate.

    Gradio Messages vs MedGemma Messages:
    -------------------------------------
    * They are almost the same, but content is represented in
      different ways.
    * In a Gradio message, content is either a one-element tuple
      holding an image path or a string of text. MedGemma expects
      the content as a list of dictionaries.
      - Why a list? To handle multiple pieces of content together.
      - Why dicts? To carry the type of each piece of content
        alongside the content itself.

    Gradio Messages: [
        {"role": "user", "content": ("cat1.png",)},
        {"role": "user", "content": ("cat2.png",)},
        {"role": "user", "content": "How do they resemble each other?"},
    ]
    MedGemma Messages: [
        {"role": "user", "content": [
            {"type": "image", "image": PIL.Image},
            {"type": "image", "image": PIL.Image},
            {"type": "text", "text": "How do they resemble each other?"},
        ]},
    ]
    """
    llm_messages = []
    if sys_prompt:
        sys_content = [{"type": "text", "text": sys_prompt}]
        llm_messages.append({"role": "system", "content": sys_content})
    # history contains all the messages up to now;
    # convert them from the Gradio format to the LLM format
    llm_messages.extend(process_history(history))
    # Turn the current user input into a message and append it
    user_message = create_user_message(user_input)
    llm_messages.append(user_message)
    # Generate a response; note that the complete message history is passed
    # to the pipe. Greedy decoding keeps the answers reproducible.
    output = pipe(llm_messages, max_new_tokens=max_tokens,
                  generate_kwargs={"do_sample": False})
    # Extract the content of the assistant message from the model output
    content = parse_output(output)
    # Stream the reply character by character to simulate typing
    for i in range(len(content)):
        time.sleep(0.01)
        yield content[: i + 1]
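
# Streaming sketch: gr.ChatInterface treats a generator function as a streamed
# response and re-renders the chatbot with each yielded prefix, so for
# content == "Hi!" the UI shows "H", then "Hi", then "Hi!".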

description = (
    "MedGemma is a variant of the Gemma 3 architecture, specifically "
    "optimized for medical text and image comprehension tasks. You can "
    "upload your images and ask text-based questions."
)

user_img = "https://huggingface.co/avatars/2c4ba9b7cf3a77322929737c35252857.svg"
gemma_img = "https://huggingface.co/spaces/goktug14/MedGemma/resolve/main/images/gemma.jpg"

demo = gr.ChatInterface(
    fn=chat,
    type="messages",
    multimodal=True,
    title="MedGemma Medical Assistant",
    description=description,
    chatbot=gr.Chatbot(
        type="messages",
        height=300,
        scale=1,
        avatar_images=(user_img, gemma_img),
    ),
    textbox=gr.MultimodalTextbox(file_types=["image"], file_count="multiple"),
    css_paths=["style.css"],
    additional_inputs=[
        gr.Textbox("You are a helpful medical AI assistant.", label="System Prompt"),
        gr.Slider(minimum=100, maximum=2048, value=300, step=1, label="Max New Tokens"),
    ],
)

demo.launch()
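
# To run outside Spaces (assuming an HF_TOKEN with MedGemma access is set in
# the environment): python app.py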