File size: 3,788 Bytes
5be55c9
27950ee
5be55c9
 
 
f8da60d
7f8a6b3
 
 
e9c6c9e
7f8a6b3
5be55c9
27950ee
 
 
 
 
 
 
 
f8da60d
 
53a81fa
 
27950ee
f8da60d
5be55c9
 
e9c6c9e
f8da60d
27950ee
 
 
e9c6c9e
 
 
27950ee
 
7f8a6b3
e9c6c9e
7f8a6b3
 
e9c6c9e
 
 
 
 
 
 
 
 
 
 
5be55c9
27950ee
 
 
5be55c9
f8da60d
 
 
 
7f8a6b3
f8da60d
 
 
7f8a6b3
f8da60d
 
 
5be55c9
 
 
 
 
e9c6c9e
 
5be55c9
7f8a6b3
 
5be55c9
e9c6c9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5be55c9
 
e9c6c9e
5be55c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import threading
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Identifier passed to from_pretrained() for both the tokenizer and the model.
MODEL_ID = "Splashdude/smollm-chatbot"

# Sent as the leading "system" turn of every prompt built by generate_response.
SYSTEM_PROMPT = " ".join(
    [
        "You are a helpful, friendly AI assistant.",
        "You give clear, accurate, and conversational answers.",
        "Remember what the user tells you in this conversation.",
    ]
)

# Both start as None and are populated by load_model() on first use.
model = tokenizer = None


def load_model():
    """Populate the module-level ``tokenizer`` and ``model`` on first call.

    Does nothing if ``model`` is already set. Otherwise loads both from
    ``MODEL_ID``, with model weights in float32, moves the model to the CPU,
    and switches it to evaluation mode.
    """
    global model, tokenizer

    if model is None:
        print("Loading model...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,
        )
        model.to("cpu")
        model.eval()  # evaluation mode: disables dropout for generation
        print("Model loaded!")


def generate_response(message, chat_history):
    """Append *message* to the conversation and stream the model's reply.

    Generator used as a Gradio event handler. ``chat_history`` is a list of
    ``{"role": ..., "content": ...}`` dicts and is mutated in place: a user
    turn and an assistant turn are appended, and the assistant turn's
    ``content`` grows as text arrives from the model.

    Args:
        message: Text the user just submitted. Empty or whitespace-only
            input is ignored (history is yielded back unchanged).
        chat_history: Conversation so far; mutated in place.

    Yields:
        ``(chat_history, "")`` tuples, one per UI update.

    Failures in ``load_model()`` or ``model.generate()`` are not raised;
    they are shown to the user in the assistant turn as ``"Error: <exc>"``.
    """
    # Validate input first so an empty submit never triggers a model load.
    if not message or not message.strip():
        yield chat_history, ""
        return

    if model is None:
        try:
            load_model()
        except Exception as e:
            chat_history.append({"role": "user", "content": message})
            chat_history.append({"role": "assistant", "content": f"Error: {e}"})
            yield chat_history, ""
            return

    chat_history.append({"role": "user", "content": message})
    chat_history.append({"role": "assistant", "content": ""})
    # Render the user's turn right away; otherwise nothing is shown until the
    # first chunk arrives (or ever, if generation produces no text).
    yield chat_history, ""

    # Prompt = system prompt + every turn except the empty assistant
    # placeholder appended above.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(
        {"role": turn["role"], "content": turn["content"]}
        for turn in chat_history[:-1]
    )

    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors="pt")

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = {
        **inputs,
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "streamer": streamer,
    }
    errors = []

    def _generate():
        # Worker-thread target. If generate() raises, the streamer never
        # receives its end-of-stream signal and the consumer loop below would
        # block forever, so record the error and close the stream explicitly.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    # NOTE(review): if the client disconnects and this generator is closed,
    # the worker thread still runs generate() to completion.
    thread = threading.Thread(target=_generate)
    thread.start()

    reply = ""
    for chunk in streamer:
        reply += chunk
        chat_history[-1]["content"] = reply
        yield chat_history, ""

    thread.join()

    if errors:
        error_text = f"Error: {errors[0]}"
        chat_history[-1]["content"] = (
            f"{reply}\n\n{error_text}" if reply else error_text
        )
        yield chat_history, ""


def clear_chat():
    """Return a fresh empty list and an empty string for resetting the chat."""
    fresh_history = []
    blank = ""
    return fresh_history, blank


with gr.Blocks(title="AI Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# AI Chatbot\nFast conversational AI powered by SmolLM2-360M.")

    chatbot = gr.Chatbot(type="messages", height=500, show_copy_button=True, label="Chat")
    # Per-session conversation history (list of {"role", "content"} dicts)
    # that generate_response appends to in place.
    chat_state = gr.State([])

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message...",
            show_label=False,
            container=False,
            scale=8,
        )
        submit = gr.Button("Send", variant="primary", scale=1)
        clear = gr.Button("New Chat", scale=1)

    gr.Examples(
        examples=[
            "Hello! How are you?",
            "Tell me a joke.",
            "What is the capital of France?",
            "Explain gravity in simple terms.",
        ],
        inputs=msg,
        label="Examples",
    )

    def user_submit(message, history):
        """Stream a reply: update the chatbot, clear the textbox, keep state."""
        for updated_history, _ in generate_response(message, history):
            yield updated_history, "", updated_history

    def reset_conversation():
        """Empty the chat display and textbox and start a fresh history list."""
        empty_chat, empty_text = clear_chat()
        # chat_state must receive a list: generate_response appends to it, so
        # sending clear_chat's "" there would break the next submit.
        return empty_chat, empty_text, []

    # Pressing Enter in the textbox and clicking Send do the same thing.
    for trigger in (msg.submit, submit.click):
        trigger(
            user_submit,
            [msg, chat_state],
            [chatbot, msg, chat_state],
            queue=True,
        )

    clear.click(reset_conversation, None, [chatbot, msg, chat_state])

if __name__ == "__main__":
    # Blocks.queue() returns the Blocks instance itself, so launch() chains.
    demo.queue().launch()