# smol-llm / app.py
# Gradio chat UI serving a locally fine-tuned causal LM ("smol-medical-meadow-FT").
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ----------------------------------------------------
# LOAD YOUR FINE–TUNED MODEL (LOCAL)
# ----------------------------------------------------
# Path to the fine-tuned model directory (relative to the Space root).
MODEL_PATH = "smol-medical-meadow-FT"  # change if your folder name is different

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# The base model ships without a pad token; reuse EOS so batching/padding works.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",          # let accelerate place the weights (CPU/GPU)
    torch_dtype=torch.float32,  # full precision — presumably a CPU Space; confirm if GPU is available
)
# Keep generate() from warning about a missing pad token.
model.config.pad_token_id = tokenizer.eos_token_id
model.config.use_cache = False  # safer for smaller models
# ----------------------------------------------------
# CHAT FUNCTION (LOCAL GENERATION)
# ----------------------------------------------------
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply with the local fine-tuned model.

    Args:
        message: The current user message (str).
        history: Prior turns. With ``ChatInterface(type="messages")`` each
            item is a dict with ``"role"`` and ``"content"`` keys.
        system_message: Text prepended to the transcript as instructions.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (must be > 0 since do_sample=True).
        top_p: Nucleus-sampling probability mass.

    Yields:
        The complete reply as a single string (non-streaming generator,
        which is the shape ChatInterface expects).
    """
    # Build a plain-text transcript: system prompt, then alternating turns.
    conversation = system_message + "\n"
    for turn in history:
        # BUG FIX: with type="messages" the history items are
        # {"role": ..., "content": ...} dicts — the old turn['user'] /
        # turn['assistant'] lookups raised KeyError on any follow-up turn.
        speaker = "User" if turn["role"] == "user" else "Assistant"
        conversation += f"{speaker}: {turn['content']}\n"

    # Append the current message and cue the model to answer.
    prompt = conversation + f"User: {message}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,  # explicit: avoids the open-end generation warning
    )

    # Decode only the tokens generated after the prompt.
    generated = output_ids[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated, skip_special_tokens=True).strip()
    yield answer
# ----------------------------------------------------
# GRADIO UI
# ----------------------------------------------------
# ----------------------------------------------------
# GRADIO UI
# ----------------------------------------------------
# Chat widget wired to respond(); type="messages" means history is passed as
# a list of {"role", "content"} dicts (OpenAI-style).
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    # Extra controls rendered below the chat box; passed to respond() in order
    # after (message, history).
    additional_inputs=[
        gr.Textbox(value="You are a helpful medical assistant.", label="System message"),
        gr.Slider(minimum=10, maximum=512, value=150, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.05, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
    ],
)

# Wrap the chat interface in a Blocks app so it can be launched/extended.
demo = gr.Blocks()
with demo:
    chatbot.render()

if __name__ == "__main__":
    demo.launch()