# app.py: Gradio chat demo for HuggingFaceTB/SmolLM3-3B
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the tokenizer and model
model_path = "HuggingFaceTB/SmolLM3-3B"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
# The tokenizer may not define a pad token; reuse EOS so padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
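
# A hedged sketch (assumption: the host may have a GPU). The model could be
# moved to CUDA for faster generation; the tokenized inputs would then also
# need moving to model.device inside generate_response(). Commented out so
# the app stays CPU-compatible as written.
# if torch.cuda.is_available():
#     model = model.to("cuda")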

SYSTEM_PROMPT = (
    "You are smollm3-3b, an AI assistant built on the SmolLM3 architecture. "
    "You are direct, efficient, and helpful.\n\n"
    "**Identity:** A concise, human-like conversational assistant.\n\n"
    "**Response Rules:** "
    "- Give ONE clear, relevant response per query. "
    "- Stay strictly on topic. No tangents, filler, or repetition. "
    "- If unsure, say you don't know instead of guessing. "
    "- Do not hallucinate information.\n\n"
    "**Response Style:** "
    "- Short, clear, natural. "
    "- Prioritize brevity and sense over detail. "
    "- Friendly by default, formal if asked.\n\n"
    "**Capabilities:** "
    "- Answer any topic sensibly."
)

def build_context(user_message):
    """Wrap the user message in the system prompt using a plain-text turn format."""
    return SYSTEM_PROMPT + "\n\nUser: " + user_message + "\n\nAssistant:"
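
# A hedged alternative sketch: SmolLM3 ships a chat template, so the prompt
# could instead be built with the standard transformers apply_chat_template
# API. Commented out to preserve the original plain-text "User:/Assistant:"
# format above.
# def build_context_with_template(user_message):
#     messages = [
#         {"role": "system", "content": SYSTEM_PROMPT},
#         {"role": "user", "content": user_message},
#     ]
#     return tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )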

def generate_response(
    prompt,
    max_tokens=300,
    temperature=0.45,
    top_p=0.95,
    repetition_penalty=1.1,
    top_k=35,
):
    formatted_input = build_context(prompt)
    inputs = tokenizer(
        formatted_input,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens (everything after the prompt).
    new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Strip the ChatML end marker in case it survives decoding.
    response = response.replace("<|im_end|>", "").strip()
    lines = response.splitlines()
    first_line = lines[0].strip() if lines else ""
    # If the model echoed an "Assistant:" label, return that line's content.
    if first_line.lower().startswith("assistant:"):
        return first_line[len("assistant:"):].strip()
    # Otherwise cut each line at any hallucinated turn label and keep the rest.
    cleaned_lines = []
    for line in lines:
        for label in ["Assistant:", "assistant:", "User:", "user:"]:
            if label in line:
                line = line.split(label)[0].strip()
        if line:
            cleaned_lines.append(line)
    return "\n".join(cleaned_lines)

def respond(
    message,
    history,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty,
    top_k,
):
    # `history` is required by gr.ChatInterface's callback signature but is
    # unused: each reply sees only the system prompt plus the latest message.
    return generate_response(
        message,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        top_k=top_k,
    )
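
# Note: gr.ChatInterface passes `additional_inputs` positionally after
# (message, history), so the slider order below must match respond()'s
# parameter order: max_tokens, temperature, top_p, repetition_penalty, top_k.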
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    title="Chat with smollm3-3b",
    description="Open-source AI model (beta). No topic restrictions; answers questions on any subject.",
    additional_inputs=[
        gr.Slider(minimum=25, maximum=500, value=50, step=10, label="Max new tokens"),
        gr.Slider(minimum=0.01, maximum=1.0, value=0.2, step=0.01, label="Temperature"),
        gr.Slider(minimum=0.5, maximum=1.0, value=0.9, step=0.01, label="Top-p (nucleus sampling)"),
        gr.Slider(minimum=1.0, maximum=1.5, value=1.1, step=0.001, label="Repetition penalty"),
        gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Top-k sampling"),
    ],
)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    chatbot.render()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True, mcp_server=True)