# boroll2347 — clean up debug prints (commit 019d1f9, verified)
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
import os
import spaces
# Select the device for model inference.
# Automatically uses the GPU if one is available and configured.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Checkpoint served by this Space.
MODEL = "pszemraj/medgemma-27b-text-heretic_med"
# Load the tokenizer from the local cache only (no network fetch at startup).
tokenizer = AutoTokenizer.from_pretrained(MODEL, local_files_only=True)
# Load the model weights:
# - bfloat16 halves memory vs fp32 while keeping the full exponent range.
# - device_map="auto" lets accelerate place/shard layers across GPU and CPU.
# - offload_folder enables spilling weights to disk when they don't fit
#   into VRAM or system RAM.
# - load_in_8bit was tried to shrink the footprint further but is left
#   disabled below; re-enable only if bitsandbytes quantization is acceptable.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    dtype=torch.bfloat16,
    device_map="auto",
    # load_in_8bit=True,
    offload_folder="./offload_dir",
    local_files_only=True,
)
# Disabled legacy UI: an earlier single-turn chat interface, kept for
# reference. The `if False:` guard means none of this executes; the active
# UI is the gr.Blocks() app defined further down in this file.
if False:
    def chat_interface(message, history):
        """
        Run one chat turn: append the user message to the history, build a
        chat-template prompt, sample a completion, and return the reply text.
        """
        chat_history = list(history)
        # Add the current user message to the chat history
        chat_history.append({"role": "user", "content": message})
        # Apply the tokenizer's chat template
        prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
        # Generate the response
        input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        outputs = model.generate(
            input_ids.to(device),
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95
        )
        # Decode the response and extract the model's part
        response = tokenizer.decode(outputs[0])
        # NOTE(review): fragile — this assumes the decoded text contains the
        # model's reply at split index 1; with a multi-turn history there can
        # be several <end_of_turn> markers, so index 1 may not be the model
        # turn. Verify against the tokenizer's chat template before reviving.
        response = response.split("<end_of_turn>")[1].strip()
        return response
    # Create the Gradio interface
    # NOTE(review): title says "4B-IT" but MODEL above is a 27b checkpoint —
    # stale text left over from an earlier model.
    gr.ChatInterface(
        fn=chat_interface,
        type="messages",
        title="MedGemma-4B-IT Medical Assistant",
        description="A fine-tuned model for medical-related questions."
    ).launch(share=True)
@spaces.GPU(duration=60)
def extend(text, max_new_tokens, chunk_size, progress=gr.Progress()):
    """Continue `text` with up to `max_new_tokens` sampled tokens.

    Generator yielding (document_text, status_markdown) pairs after each
    chunk, so Gradio can stream the growing document and the Abort button
    can cancel between chunks.

    Args:
        text: Existing document text to extend.
        max_new_tokens: Upper bound on the number of tokens to generate.
        chunk_size: Tokens generated per model.generate() call; smaller
            values stream more responsively but add per-call overhead.
        progress: Gradio progress tracker (injected by the UI).
    """
    PREFIX = "<bos>\n"  # Model just repeats the last token without this
    progress(0, desc="Tokenizing...")
    token_ids = tokenizer.encode(PREFIX + text, add_special_tokens=False, return_tensors="pt")
    # KV cache shared across chunks, so each generate() call only pays for
    # the new tokens (cache lives for this call only, not across clicks).
    past_key_values = DynamicCache(config=model.config)
    done_tokens = 0
    try:
        # Generate in loop to allow it to be interrupted
        while done_tokens < max_new_tokens:
            progress(done_tokens / max_new_tokens, desc="Generating...")
            chunk_max_new_tokens = min(chunk_size, max_new_tokens - done_tokens)
            new_ids = model.generate(
                token_ids.to(device),
                max_new_tokens=chunk_max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                past_key_values=past_key_values,  # continue from where we left off
            )
            chunk_new_tokens = new_ids.shape[1] - token_ids.shape[1]
            done_tokens += chunk_new_tokens
            token_ids = new_ids
            (unwrapped_new_ids,) = new_ids
            new_text = tokenizer.decode(unwrapped_new_ids).removeprefix(PREFIX)
            if not new_text.startswith(text):
                # Sanity check: decoding should only ever append to the input.
                yield text, "New text somehow deleted existing text!\n\n" + new_text
                return
            # BUG FIX: yield BEFORE checking for early stop. The original
            # broke out of the loop before updating state or yielding, so a
            # final partial chunk (model emitted EOS) was silently discarded.
            yield new_text, f"New tokens generated: {done_tokens}/{max_new_tokens}"
            if chunk_new_tokens < chunk_max_new_tokens:
                break  # Model decided to stop early
    except Exception as e:
        # Surface any failure in the status box instead of a stack trace.
        yield text, f"# ERROR: {e!r}"
# Gate for the exec-based debug console wired into the "Debug" tab below.
# NOTE: the docstrings here are runtime-visible — the UI uses `debug.__doc__`
# as the input-box label — so they double as user-facing instructions.
DEBUG_ENABLED = False
if DEBUG_ENABLED:
    def debug(cmd):
        """Run `result.append(...)` to display values."""
        # SECURITY: exec() of arbitrary user-supplied code — anyone who can
        # reach this tab gets remote code execution on the Space. Must stay
        # disabled on any shared or public deployment.
        # `result.append(...)` mutates the list object captured in locals(),
        # so values appended inside exec survive the call.
        result = []
        exec(cmd, globals(), locals())
        return repr(result)
else:
    def debug(x):
        """Debug print the input."""
        # Safe fallback: just echo the repr of whatever the user typed.
        return repr(x)
# Active UI: a Blocks app with an "Extend" tab (streamed continuation of a
# document) and a "Debug" tab (repr/exec console, depending on DEBUG_ENABLED).
with gr.Blocks() as demo:
    gr.Markdown("# Medical Text Generation")
    gr.Markdown(f"Model in use: {MODEL}")
    with gr.Tab("Extend"):
        gr.Markdown("Enter some medical text, and press Generate to continue it.")
        gr.Markdown("To allow interrupting the generation, it occurs in chunks, remembering the KV cache (only during the generation, not currently across executions).")
        gr.Markdown("Raising the chunk size will increase latency, but might make it go faster by reducing overhead.")
        # Editable document; `extend` streams its growing output back here.
        document = gr.Code(
            language="markdown",
            interactive=True,
            wrap_lines=True,
        )
        max_new_tokens = gr.Slider(label="Max New Tokens", minimum=10, maximum=8192, step=10, value=128)
        chunk_size = gr.Slider(label="Streaming Chunk Size", minimum=1, maximum=100, step=1, value=5)
        with gr.Row():
            generate_button = gr.Button("Generate")
            abort_button = gr.Button("Abort")
        # `extend` is a generator: each yielded (text, status) pair updates
        # the document and the status box, streaming chunk by chunk.
        generate_event = generate_button.click(
            fn=extend,
            inputs=[
                document,
                max_new_tokens,
                chunk_size,
            ],
            outputs=[
                document,
                # Inline status / error display created below the document.
                gr.Code(
                    label="Status",
                    language="markdown",
                    interactive=False,
                    wrap_lines=True,
                ),
            ],
            show_progress="minimal",
        )
        # Abort cancels the in-flight generate event; fn=None means no
        # server-side handler runs — only the cancellation.
        abort_button.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=[generate_event],
        )
    with gr.Tab("Debug"):
        # The input label is debug's docstring, so the UI itself shows which
        # variant (exec console vs plain repr) is currently active.
        gr.Interface(
            fn=debug,
            inputs=[gr.Code(
                label=debug.__doc__,
                language="python",
                interactive=True,
                wrap_lines=True,
            )],
            outputs=[gr.Code(
                language="python",
                wrap_lines=True,
            )],
        )
demo.launch()