import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
import os
import spaces
# Pick the inference device. Inputs are moved here before generate(); with
# device_map="auto" below, accelerate handles placing the model weights.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hub repo id of the fine-tuned causal LM used by the UI below.
MODEL = "pszemraj/medgemma-27b-text-heretic_med"

# Tokenizer must come from the same repo so its chat template and special
# tokens (e.g. <bos>, <end_of_turn>) match the model.
# local_files_only=True: never hit the network; weights are expected in the local cache.
tokenizer = AutoTokenizer.from_pretrained(MODEL, local_files_only=True)

# Load the model:
# - device_map="auto" places/shards the model automatically (GPU -> CPU -> disk).
# - offload_folder enables disk offloading when the model can't fit in VRAM or RAM.
# - 8-bit quantization was tried to reduce the memory footprint but is
#   currently disabled (line kept commented for reference).
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    dtype=torch.bfloat16,
    device_map="auto",
    # load_in_8bit=True,
    offload_folder="./offload_dir",
    local_files_only=True,
)
# Legacy chat UI, deliberately disabled (`if False`) — kept for reference.
if False:
    def chat_interface(message, history):
        """Generate one chat reply from the model.

        Args:
            message: The current user message (str).
            history: Prior turns as a list of {"role", "content"} dicts
                (gradio `type="messages"` format).

        Returns:
            The model's reply text for this turn.
        """
        chat_history = list(history)
        # Append the current user message so the template sees the full conversation.
        chat_history.append({"role": "user", "content": message})
        # Render the conversation with the model's chat template and open an
        # assistant turn for generation.
        prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
        # add_special_tokens=False: the template already inserts <bos> etc.
        input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        outputs = model.generate(
            input_ids.to(device),
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )
        # Decode ONLY the newly generated tokens. The previous approach
        # (`decode(outputs[0]).split("<end_of_turn>")[1]`) picked the segment
        # after the FIRST turn marker, which is an old history turn — not the
        # new reply — whenever the conversation had more than one turn.
        response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
        return response.strip()

    # Title derived from MODEL: the old hard-coded "MedGemma-4B-IT" disagreed
    # with the 27b repo actually loaded above.
    gr.ChatInterface(
        fn=chat_interface,
        type="messages",
        title=f"{MODEL} Medical Assistant",
        description="A fine-tuned model for medical-related questions."
    ).launch(share=True)
@spaces.GPU(duration=60)
def extend(text, max_new_tokens, chunk_size, progress=gr.Progress()):
    """Continue `text`, streaming results chunk by chunk.

    Generates up to `max_new_tokens` tokens in chunks of `chunk_size`,
    yielding `(updated_text, status_message)` after each chunk so the UI can
    stream output and the Abort button can cancel between chunks. The KV
    cache (`past_key_values`) is reused across chunks within one call, but
    not across calls.

    Args:
        text: Existing document text to continue.
        max_new_tokens: Total new-token budget for this call.
        chunk_size: Tokens generated per yield; larger = less overhead,
            higher latency per update.
        progress: Gradio progress tracker (injected by gradio).

    Yields:
        (text, status) tuples; on error, the original text plus an error note.
    """
    PREFIX = "<bos>\n"  # Model just repeats the last token without this
    progress(0, desc="Tokenizing...")
    token_ids = tokenizer.encode(PREFIX + text, add_special_tokens=False, return_tensors="pt")
    past_key_values = DynamicCache(config=model.config)
    done_tokens = 0
    try:
        # Generate in a loop so the stream can be interrupted between chunks.
        while done_tokens < max_new_tokens:
            progress(done_tokens / max_new_tokens, desc="Generating...")
            chunk_max_new_tokens = min(chunk_size, max_new_tokens - done_tokens)
            new_ids = model.generate(
                token_ids.to(device),
                max_new_tokens=chunk_max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                past_key_values=past_key_values,  # continue from where we left off
            )
            chunk_new_tokens = new_ids.shape[1] - token_ids.shape[1]
            # Fold the chunk in BEFORE checking for early stop: the previous
            # version broke out first, silently discarding the final partial
            # chunk the model produced before emitting EOS.
            done_tokens += chunk_new_tokens
            token_ids = new_ids
            (unwrapped_new_ids,) = new_ids
            new_text = tokenizer.decode(unwrapped_new_ids).removeprefix(PREFIX)
            # Sanity check: decoding should only ever append to the input text.
            if not new_text.startswith(text):
                yield text, "New text somehow deleted existing text!\n\n" + new_text
                return
            yield new_text, f"New tokens generated: {done_tokens}/{max_new_tokens}"
            if chunk_new_tokens < chunk_max_new_tokens:
                break  # Model decided to stop early (partial chunk already yielded)
    except Exception as e:
        # UI boundary: surface the error in the status pane instead of crashing.
        yield text, f"# ERROR: {e!r}"
# Gate for the exec-based debugger. Leave False in production: the enabled
# branch executes arbitrary user-supplied code.
DEBUG_ENABLED = False

if not DEBUG_ENABLED:
    # Safe default: just show the repr of whatever value was passed in.
    def debug(x):
        """Debug print the input."""
        return repr(x)
else:
    # SECURITY: exec() of untrusted UI input — reachable only when the
    # gate above is flipped to True.
    def debug(cmd):
        """Run `result.append(...)` to display values."""
        result = []
        exec(cmd, globals(), locals())
        return repr(result)
# Top-level UI: two tabs (Extend, Debug) rendered with gr.Blocks.
with gr.Blocks() as demo:
    gr.Markdown("# Medical Text Generation")
    gr.Markdown(f"Model in use: {MODEL}")
    with gr.Tab("Extend"):
        gr.Markdown("Enter some medical text, and press Generate to continue it.")
        gr.Markdown("To allow interrupting the generation, it occurs in chunks, remembering the KV cache (only during the generation, not currently across executions).")
        gr.Markdown("Raising the chunk size will increase latency, but might make it go faster by reducing overhead.")
        # Editable document pane; `extend` both reads from and streams back into it.
        document = gr.Code(
            language="markdown",
            interactive=True,
            wrap_lines=True,
        )
        max_new_tokens = gr.Slider(label="Max New Tokens", minimum=10, maximum=8192, step=10, value=128)
        chunk_size = gr.Slider(label="Streaming Chunk Size", minimum=1, maximum=100, step=1, value=5)
        with gr.Row():
            generate_button = gr.Button("Generate")
            abort_button = gr.Button("Abort")
        # `extend` is a generator, so each yield updates (document, status).
        # The event handle is kept so the Abort button can cancel it.
        generate_event = generate_button.click(
            fn=extend,
            inputs=[
                document,
                max_new_tokens,
                chunk_size,
            ],
            outputs=[
                document,
                # Status pane created inline; gradio renders it here in the tab.
                gr.Code(
                    label="Status",
                    language="markdown",
                    interactive=False,
                    wrap_lines=True,
                ),
            ],
            show_progress="minimal",
        )
        # fn=None: no server-side work — this click exists only to cancel
        # the in-flight generate_event.
        abort_button.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=[generate_event],
        )
    with gr.Tab("Debug"):
        # Which `debug` is wired here depends on DEBUG_ENABLED above; its
        # docstring doubles as the input label.
        gr.Interface(
            fn=debug,
            inputs=[gr.Code(
                label=debug.__doc__,
                language="python",
                interactive=True,
                wrap_lines=True,
            )],
            outputs=[gr.Code(
                language="python",
                wrap_lines=True,
            )],
        )
demo.launch()