import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# `accelerate` must be installed for device_map="auto" below, but it never
# needs to be imported directly.

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load model & tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
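# Rough sizing assumption: 1.1B parameters occupy about 2.2 GB of weight memory
# in float16 (GPU) or roughly 4.4 GB in float32 (CPU), plus runtime overhead.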

def generate_reply(user_input, history):
    # Rebuild the conversation so the model sees prior turns; apply_chat_template
    # emits TinyLlama's Zephyr-style <|user|>/<|assistant|> markup, including the
    # </s> separators the hand-built prompt string was missing.
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": user_input.strip()})
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,  # without this, temperature/top_k/top_p are ignored
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt, so no
    # fragile split on the "<|assistant|>" marker is needed.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    history.append((user_input, reply))
    return history, history
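
# Note: generation above is blocking; for token-by-token streaming in the UI,
# transformers' TextIteratorStreamer could be wired into a generator callback.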

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Open-Source Chatbot Demo")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Type a question and press Enter")
    state = gr.State([])

    # Generate a reply, then clear the textbox for the next question
    msg.submit(generate_reply, [msg, state], [chatbot, state]).then(
        lambda: "", None, msg
    )
    gr.Button("Clear").click(lambda: ([], []), None, [chatbot, state])

demo.launch()
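
# Run locally with `python app.py`, assuming the dependencies are installed,
# e.g. via `pip install gradio transformers torch accelerate`. Passing
# share=True to demo.launch() would also create a temporary public URL.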