import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

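# LoRA adapter fine-tuned for function calling, and the base model it was trained from.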
ADAPTER_ID = "GhostScientist/smollm2-360m-function-calling-sft"
BASE_MODEL_ID = "HuggingFaceTB/SmolLM2-360M-Instruct"

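# The tokenizer is cheap to load, so load it eagerly at import time; the model
# itself is loaded lazily inside the GPU-decorated handler (see load_model below).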
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

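# Module-level cache: the merged model is built once per process, not per request.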
model = None


def load_model():
    """Load the base model, apply the LoRA adapter, and merge it in.

    merge_and_unload() folds the adapter weights into the base weights, so
    subsequent inference runs as a plain transformers model with no PEFT
    overhead.
    """
    global model
    if model is None:
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
        model = model.merge_and_unload()
    return model


@spaces.GPU(duration=120)  # Request a ZeroGPU slot for up to 120 s per call.
def generate_response(message, history, system_message, max_tokens, temperature, top_p):
    model = load_model()

    messages = [{"role": "system", "content": system_message}]

    # History may arrive as OpenAI-style role/content dicts (type="messages")
    # or as legacy (user, assistant) pairs; accept both.
    for item in history:
        if isinstance(item, dict):
            messages.append({"role": item["role"], "content": item["content"]})
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, assistant_msg = item
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )
    return response


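# Build the chat UI. type="messages" makes Gradio pass history to the handler
# as role/content dicts rather than (user, assistant) tuples.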
demo = gr.ChatInterface(
    generate_response,
    title="SmolLM2 360M Function Calling",
    description="A LoRA fine-tuned SmolLM2-360M model for function calling, powered by ZeroGPU (free!)",
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful assistant that can call functions when needed.",
            label="System message",
            lines=2,
        ),
        gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
    examples=[
        ["Hello! What can you help me with?"],
        ["What's the weather like in San Francisco?"],
        ["Can you search for the latest news about AI?"],
    ],
    type="messages",
)

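# Start the Gradio server; Spaces runs this file directly, so this block executes there too.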
if __name__ == "__main__":
    demo.launch()