import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
import os
import spaces

# Set the device for model inference
# This will automatically use the GPU if one is available and configured
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL = "pszemraj/medgemma-27b-text-heretic_med"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL, local_files_only=True)

# Load the model with `device_map="auto"`; `offload_folder` gives it a place on
# disk for weights that fit in neither VRAM nor system RAM.
# `load_in_8bit=True` (currently commented out) would further reduce the
# memory footprint.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    dtype=torch.bfloat16,
    device_map="auto",
    # load_in_8bit=True,
    offload_folder="./offload_dir",
    local_files_only=True,
)

# Legacy chat UI, kept for reference but disabled; the Blocks app at the bottom
# of this file is the one that gets launched.
CHAT_DEMO_ENABLED = False

if CHAT_DEMO_ENABLED:

    def chat_interface(message, history):
        """
        Main chat function to interact with the model.

        `history` is the prior conversation as passed by `gr.ChatInterface`
        with `type="messages"`; `message` is the new user turn. Returns only
        the model's reply text.
        """
        chat_history = list(history)
        # Add the current user message to the chat history
        chat_history.append({"role": "user", "content": message})
        # Apply the tokenizer's chat template
        prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
        # Generate the response
        input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        outputs = model.generate(
            input_ids.to(device),
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )
        # Decode only the tokens generated after the prompt. (The previous
        # `response.split("")` always raised ValueError: empty separator.)
        reply_ids = outputs[0][input_ids.shape[1]:]
        return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()

    # Create the Gradio interface
    gr.ChatInterface(
        fn=chat_interface,
        type="messages",
        title="MedGemma-4B-IT Medical Assistant",
        description="A fine-tuned model for medical-related questions.",
    ).launch(share=True)


def _eos_token_ids(generation_config):
    """Return the set of end-of-sequence token ids configured for generation."""
    eos = generation_config.eos_token_id
    if eos is None:
        return set()
    if isinstance(eos, int):
        return {eos}
    return {int(token_id) for token_id in eos}


@spaces.GPU(duration=60)
def extend(text, max_new_tokens, chunk_size, progress=gr.Progress()):
    """
    Continue `text` with the model, streaming the growing document to the UI.

    Generation runs in chunks of at most `chunk_size` new tokens so the
    request can be cancelled between chunks; the KV cache is reused across
    chunks within one call (not across calls).

    Yields `(document_text, status_markdown)` pairs. On failure the document
    is left at the last good text and the status reports the problem.
    """
    PREFIX = "\n"  # Model just repeats the last token without this
    # Slider values may arrive as floats; token counts must be ints, and a
    # chunk size below 1 would never make progress.
    max_new_tokens = int(max_new_tokens)
    chunk_size = max(1, int(chunk_size))
    latest_text = text  # last decoded text known to correctly extend `text`

    progress(0, desc="Tokenizing...")
    token_ids = tokenizer.encode(PREFIX + text, add_special_tokens=False, return_tensors="pt")
    past_key_values = DynamicCache(config=model.config)
    eos_ids = _eos_token_ids(model.generation_config)
    done_tokens = 0
    try:
        # Generate in loop to allow it to be interrupted
        while done_tokens < max_new_tokens:
            progress(done_tokens / max_new_tokens, desc="Generating...")
            chunk_max_new_tokens = min(chunk_size, max_new_tokens - done_tokens)
            new_ids = model.generate(
                token_ids.to(device),
                max_new_tokens=chunk_max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                past_key_values=past_key_values,  # continue from where we left off
            )
            chunk_new_tokens = new_ids.shape[1] - token_ids.shape[1]
            hit_eos = chunk_new_tokens > 0 and int(new_ids[0, -1]) in eos_ids
            # Fewer tokens than requested (or an EOS, even exactly at the chunk
            # boundary) means the model decided to stop.
            stopped = hit_eos or chunk_new_tokens < chunk_max_new_tokens
            if hit_eos:
                # Keep the end-of-sequence marker out of the document text.
                new_ids = new_ids[:, :-1]
                chunk_new_tokens -= 1
            done_tokens += chunk_new_tokens
            token_ids = new_ids

            (unwrapped_new_ids,) = new_ids
            new_text = tokenizer.decode(unwrapped_new_ids).removeprefix(PREFIX)
            if not new_text.startswith(text):
                yield latest_text, "New text somehow deleted existing text!\n\n" + new_text
                return
            latest_text = new_text

            status = f"New tokens generated: {done_tokens}/{max_new_tokens}"
            if stopped:
                # Emit the final partial chunk before finishing; previously
                # the loop broke out first and these tokens were discarded.
                yield new_text, status + " (model stopped early)"
                return
            yield new_text, status
    except Exception as e:
        # UI boundary: report the error in the status box instead of failing
        # the event, and keep whatever text was generated before the failure.
        yield latest_text, f"# ERROR: {e!r}"


DEBUG_ENABLED = False

# The docstrings below are shown as the Debug tab's input label
# (`label=debug.__doc__`), so they are user-facing text.
if DEBUG_ENABLED:

    def debug(cmd):
        """Run `result.append(...)` to display values."""
        # SECURITY: this executes arbitrary Python sent from the browser.
        # Never enable it on a publicly reachable deployment.
        result = []
        exec(cmd, globals(), locals())
        return repr(result)

else:

    def debug(x):
        """Debug print the input."""
        return repr(x)


with gr.Blocks() as demo:
    gr.Markdown("# Medical Text Generation")
    gr.Markdown(f"Model in use: {MODEL}")

    with gr.Tab("Extend"):
        gr.Markdown("Enter some medical text, and press Generate to continue it.")
        gr.Markdown("To allow interrupting the generation, it occurs in chunks, remembering the KV cache (only during the generation, not currently across executions).")
        gr.Markdown("Raising the chunk size will increase latency, but might make it go faster by reducing overhead.")
        document = gr.Code(
            language="markdown",
            interactive=True,
            wrap_lines=True,
        )
        max_new_tokens = gr.Slider(label="Max New Tokens", minimum=10, maximum=8192, step=10, value=128)
        chunk_size = gr.Slider(label="Streaming Chunk Size", minimum=1, maximum=100, step=1, value=5)
        with gr.Row():
            generate_button = gr.Button("Generate")
            abort_button = gr.Button("Abort")
        generate_event = generate_button.click(
            fn=extend,
            inputs=[
                document,
                max_new_tokens,
                chunk_size,
            ],
            outputs=[
                document,
                gr.Code(
                    label="Status",
                    language="markdown",
                    interactive=False,
                    wrap_lines=True,
                ),
            ],
            show_progress="minimal",
        )
        # Cancelling the streaming event stops `extend` between chunks.
        abort_button.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=[generate_event],
        )

    with gr.Tab("Debug"):
        gr.Interface(
            fn=debug,
            inputs=[gr.Code(
                label=debug.__doc__,
                language="python",
                interactive=True,
                wrap_lines=True,
            )],
            outputs=[gr.Code(
                language="python",
                wrap_lines=True,
            )],
        )

demo.launch()