Spaces:
Running on Zero
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache | |
| import os | |
| import spaces | |
# Select the inference device: CUDA when a GPU is available, else CPU.
# (With device_map="auto" below, this is mainly used to place input ids.)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face repo ID of the medical text model served by this app.
MODEL = "pszemraj/medgemma-27b-text-heretic_med"

# Load the tokenizer from the local cache only (no network download).
tokenizer = AutoTokenizer.from_pretrained(MODEL, local_files_only=True)

# Load the model in bfloat16 with automatic device placement.
# `offload_folder` enables disk offloading for weights that do not fit
# into VRAM or system RAM. 8-bit loading is currently DISABLED (the
# option is kept below, commented out, for easy re-enabling).
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    dtype=torch.bfloat16,
    device_map="auto",
    # load_in_8bit=True,  # optional 8-bit quantization, currently off
    offload_folder="./offload_dir",
    local_files_only=True,
)
# NOTE: legacy chat UI, deliberately disabled with `if False:` — the
# Blocks app further down is the active interface. Kept for reference.
if False:
    def chat_interface(message, history):
        """Generate one chat reply from the model.

        Args:
            message: The latest user message.
            history: Prior turns as a list of ``{"role", "content"}``
                dicts (Gradio ``type="messages"`` format).

        Returns:
            The model's reply text for this turn.
        """
        chat_history = list(history)
        # Add the current user message to the chat history.
        chat_history.append({"role": "user", "content": message})

        # Render the conversation via the tokenizer's chat template and
        # append the generation prompt so the model answers next.
        prompt = tokenizer.apply_chat_template(
            chat_history, tokenize=False, add_generation_prompt=True
        )
        input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        outputs = model.generate(
            input_ids.to(device),
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )
        # FIX: decode only the newly generated token slice. The previous
        # `decode(outputs[0]).split("<end_of_turn>")[1]` selected the wrong
        # segment as soon as the history contained more than one turn,
        # because every prior turn also ends with "<end_of_turn>".
        new_tokens = outputs[0][input_ids.shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # FIX: title previously said "MedGemma-4B-IT" but MODEL is a 27b
    # checkpoint; use a size-neutral title to avoid the mismatch.
    gr.ChatInterface(
        fn=chat_interface,
        type="messages",
        title="MedGemma Medical Assistant",
        description="A fine-tuned model for medical-related questions.",
    ).launch(share=True)
def extend(text, max_new_tokens, chunk_size, progress=gr.Progress()):
    """Continue `text` with up to `max_new_tokens` sampled tokens.

    Generation happens in chunks of `chunk_size` so the Gradio event can
    be cancelled between chunks; the KV cache is reused across chunks
    (only within one call, not across executions).

    Args:
        text: The document to extend.
        max_new_tokens: Upper bound on tokens to generate in total.
        chunk_size: Tokens generated per `model.generate` call.
        progress: Gradio progress tracker (injected by Gradio).

    Yields:
        ``(document_text, status_message)`` tuples for the two outputs.
    """
    PREFIX = "<bos>\n"  # Model just repeats the last token without this
    progress(0, desc="Tokenizing...")
    token_ids = tokenizer.encode(PREFIX + text, add_special_tokens=False, return_tensors="pt")
    # KV cache shared across generate() calls so each chunk continues
    # from where the previous one left off.
    past_key_values = DynamicCache(config=model.config)
    done_tokens = 0
    try:
        # Generate in a loop to allow it to be interrupted between chunks.
        while done_tokens < max_new_tokens:
            progress(done_tokens / max_new_tokens, desc="Generating...")
            chunk_max_new_tokens = min(chunk_size, max_new_tokens - done_tokens)
            new_ids = model.generate(
                token_ids.to(device),
                max_new_tokens=chunk_max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                past_key_values=past_key_values,  # continue from where we left off
            )
            chunk_new_tokens = new_ids.shape[1] - token_ids.shape[1]
            # FIX: when the model stops early (EOS), the old code `break`ed
            # before updating `token_ids` or yielding, silently discarding
            # the final partial chunk. Keep it, yield it, then stop.
            stopped_early = chunk_new_tokens < chunk_max_new_tokens
            done_tokens += chunk_new_tokens
            token_ids = new_ids
            (unwrapped_new_ids,) = new_ids
            new_text = tokenizer.decode(unwrapped_new_ids).removeprefix(PREFIX)
            if not new_text.startswith(text):
                # Sanity check: decoding round-trip must preserve the input.
                yield text, "New text somehow deleted existing text!\n\n" + new_text
                return
            yield new_text, f"New tokens generated: {done_tokens}/{max_new_tokens}"
            if stopped_early:
                break  # Model decided to stop early
    except Exception as e:
        # Top-level UI boundary: surface the error in the status pane
        # instead of crashing the Gradio event.
        yield text, f"# ERROR: {e!r}"
# Gate for the exec-based console in the "Debug" tab. When off, `debug`
# is a harmless repr-echo.
DEBUG_ENABLED = False

if not DEBUG_ENABLED:
    def debug(x):
        """Debug print the input."""
        return repr(x)
else:
    def debug(cmd):
        """Run `result.append(...)` to display values."""
        # SECURITY: exec() evaluates arbitrary Python typed into the UI.
        # This path must stay disabled on any shared deployment.
        result = []
        exec(cmd, globals(), locals())
        return repr(result)
# Main Gradio app: an "Extend" tab for text continuation plus a "Debug" tab.
with gr.Blocks() as demo:
    gr.Markdown("# Medical Text Generation")
    gr.Markdown(f"Model in use: {MODEL}")
    with gr.Tab("Extend"):
        gr.Markdown("Enter some medical text, and press Generate to continue it.")
        gr.Markdown("To allow interrupting the generation, it occurs in chunks, remembering the KV cache (only during the generation, not currently across executions).")
        gr.Markdown("Raising the chunk size will increase latency, but might make it go faster by reducing overhead.")
        # Editable document: `extend` reads it as input and streams the
        # continued text back into it as output.
        document = gr.Code(
            language="markdown",
            interactive=True,
            wrap_lines=True,
        )
        max_new_tokens = gr.Slider(label="Max New Tokens", minimum=10, maximum=8192, step=10, value=128)
        chunk_size = gr.Slider(label="Streaming Chunk Size", minimum=1, maximum=100, step=1, value=5)
        with gr.Row():
            generate_button = gr.Button("Generate")
            abort_button = gr.Button("Abort")
        # The click event is kept in a variable so Abort can cancel it.
        generate_event = generate_button.click(
            fn=extend,
            inputs=[
                document,
                max_new_tokens,
                chunk_size,
            ],
            outputs=[
                document,
                # Read-only status pane, created inline as the second output.
                gr.Code(
                    label="Status",
                    language="markdown",
                    interactive=False,
                    wrap_lines=True,
                ),
            ],
            show_progress="minimal",
        )
        # Abort runs no function; it only cancels the in-flight generation.
        abort_button.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=[generate_event],
        )
    with gr.Tab("Debug"):
        # `debug.__doc__` doubles as the input label, so the label reflects
        # whichever `debug` implementation is active (repr echo vs. exec).
        gr.Interface(
            fn=debug,
            inputs=[gr.Code(
                label=debug.__doc__,
                language="python",
                interactive=True,
                wrap_lines=True,
            )],
            outputs=[gr.Code(
                language="python",
                wrap_lines=True,
            )],
        )

demo.launch()