import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ----------------------------------------------------
# LOAD YOUR FINE-TUNED MODEL (LOCAL)
# ----------------------------------------------------
MODEL_PATH = "smol-medical-meadow-FT"  # change if your folder name is different

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Reuse EOS as the padding token so tokenization/generation have a pad id.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float32,
)
model.config.pad_token_id = tokenizer.eos_token_id
# NOTE(review): the original author marked this "safer for smaller models";
# disabling the cache presumably slows generation -- confirm it is needed.
model.config.use_cache = False

# Speaker labels used in the plain-text prompt, keyed by Gradio message role.
_ROLE_LABELS = {"user": "User", "assistant": "Assistant"}


# ----------------------------------------------------
# CHAT FUNCTION (LOCAL GENERATION)
# ----------------------------------------------------
def _format_history(history):
    """Render the chat history as ``User: ...`` / ``Assistant: ...`` lines.

    Accepts the ``type="messages"`` format (a list of dicts with ``role`` and
    ``content`` keys) and, as a fallback, legacy ``[user, assistant]`` pairs.
    Entries with an unknown role or non-text content are skipped.
    """
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            label = _ROLE_LABELS.get(turn.get("role"))
            content = turn.get("content")
            if label is not None and isinstance(content, str):
                lines.append(f"{label}: {content}\n")
        else:
            # Legacy pair format: (user_text, assistant_text).
            user_text, assistant_text = turn
            if isinstance(user_text, str):
                lines.append(f"User: {user_text}\n")
            if isinstance(assistant_text, str):
                lines.append(f"Assistant: {assistant_text}\n")
    return "".join(lines)


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate the assistant's reply for a ``gr.ChatInterface`` turn.

    Args:
        message: The user's current message.
        history: Previous conversation turns as supplied by Gradio.
        system_message: Text placed on the first line of the prompt.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Yields:
        The decoded, whitespace-stripped reply (a single item).
    """
    # Build a plain-text transcript ending with an open "Assistant:" turn.
    prompt = (
        system_message
        + "\n"
        + _format_history(history)
        + f"User: {message}\nAssistant:"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    output_ids = model.generate(
        **inputs,
        # Cast defensively: the value comes from a UI slider.
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, dropping the echoed prompt.
    prompt_length = inputs["input_ids"].shape[1]
    generated = output_ids[0][prompt_length:]
    answer = tokenizer.decode(generated, skip_special_tokens=True).strip()

    yield answer


# ----------------------------------------------------
# GRADIO UI
# ----------------------------------------------------
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a helpful medical assistant.", label="System message"),
        gr.Slider(minimum=10, maximum=512, value=150, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.05, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
    ],
)

demo = gr.Blocks()
with demo:
    chatbot.render()

if __name__ == "__main__":
    demo.launch()