|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
# Display name -> Hugging Face Hub model id for every model the UI offers.
# Keys appear verbatim in the model-selector dropdown.
MODEL_OPTIONS = {
    "Mistral-7B-Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
    "Qwen2.5-3B-Instruct": "Qwen/Qwen2.5-3B-Instruct",
    "Qwen2.5-1.5B-Instruct": "Qwen/Qwen2.5-1.5B-Instruct",
    "StableLM2-1.6B": "stabilityai/stablelm-2-zephyr-1_6b",
    "SmolLM3-3B": "HuggingFaceTB/SmolLM3-3B",
    # NOTE(review): this is a *base* (non-instruct) model; it may respond
    # poorly to the chat-style prompt used below — confirm it belongs here.
    "BTLM-3B-8k-base": "cerebras/btlm-3b-8k-base"
}

# In-process cache: model_key -> (tokenizer, model), filled by load_model().
loaded = {}

# System prompt prepended to every conversation sent to the model.
SYSTEM_PROMPT = "You are HugginGPT — helpful, friendly, and clear with memory."
|
|
|
|
|
def load_model(model_key):
    """Return the (tokenizer, model) pair for *model_key*, loading it once.

    Results are memoized in the module-level ``loaded`` dict so each model
    is downloaded/instantiated at most once per process.

    Args:
        model_key: A key of ``MODEL_OPTIONS`` (a dropdown display name).

    Returns:
        Tuple of ``(tokenizer, model)``.

    Raises:
        KeyError: If ``model_key`` is not in ``MODEL_OPTIONS``.
    """
    # Serve cache hits before touching MODEL_OPTIONS or the Hub: the
    # original resolved model_id first, doing pointless work on every hit.
    if model_key in loaded:
        return loaded[model_key]

    model_id = MODEL_OPTIONS[model_key]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        # fp16 halves GPU memory for inference. NOTE(review): on a
        # CPU-only host fp16 generation can be very slow or unsupported —
        # consider a float32 fallback; confirm deployment target.
        torch_dtype=torch.float16,
    )

    loaded[model_key] = (tokenizer, model)
    return tokenizer, model
|
|
|
|
|
def generate_response(message, history, model_choice):
    """Generate one assistant reply for the Gradio chat callback.

    Args:
        message: The user's newest message.
        history: Prior turns as ``[(user_text, assistant_text), ...]``
            (tuple-style Gradio history; may be None/empty).
        model_choice: Key into ``MODEL_OPTIONS`` selecting the model.

    Returns:
        The assistant's reply text (new tokens only).
    """
    tokenizer, model = load_model(model_choice)

    # Build a role/content message list from the conversation so far.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for user_turn, assistant_turn in (history or []):
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    # Prefer the model's own chat template: instruct-tuned models expect
    # the exact prompt format they were trained on, not an ad-hoc
    # "role: text" transcript. Fall back to the plain transcript for base
    # models that ship no template.
    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt = "".join(
            f"{m['role']}: {m['content']}\n" for m in messages
        ) + "assistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        # Models without a pad token (e.g. many causal LMs) warn or
        # misbehave during generation unless one is supplied.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the freshly generated tokens. The original decoded the
    # whole sequence and split on "assistant:", which breaks whenever the
    # reply itself contains that string or the template renders roles
    # differently.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## HugginGPT")

    model_selector = gr.Dropdown(
        choices=list(MODEL_OPTIONS.keys()),
        value="Mistral-7B-Instruct",
        label="Select model",
    )

    # BUG FIX: the original wrapped generate_response in a lambda that read
    # model_selector.value — that is the component's *static default*, so
    # the user's live dropdown selection was never used and the app always
    # ran Mistral-7B-Instruct. Registering the dropdown as an additional
    # input makes Gradio forward the current selection as the third
    # argument (model_choice) on every call.
    chat = gr.ChatInterface(
        fn=generate_response,
        additional_inputs=[model_selector],
        title="HugginGPT",
    )

demo.launch()
|
|
|