# (removed: non-Python web-scrape artifacts — a "File size" line, repeated
#  commit-hash strings, and a copied line-number gutter that made the file
#  unparseable)
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread

# Hugging Face Hub repo id of the causal LM served by this app.
MODEL_NAME = "S1mp1eXXX/Nimi-1b-thinking"

# Load once at startup (important): downloading/initializing the model per
# request would be prohibitively slow, and Gradio workers share these globals.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# fp16 on GPU for memory/speed; fall back to fp32 on CPU, where fp16 is
# poorly supported. device_map="auto" lets accelerate place the weights.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)

def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Stream an assistant reply for *message*, given the prior chat *history*.

    Yields the accumulated reply text after every new token so Gradio can
    render the response incrementally.

    Parameters
    ----------
    message : str
        The latest user message.
    history : list[dict]
        Prior turns as ``{"role": ..., "content": ...}`` dicts
        (Gradio ``type="messages"`` format).
    system_message : str
        System prompt prepended to the conversation.
    max_tokens : int
        Maximum number of new tokens to generate.
    temperature : float
        Sampling temperature.
    top_p : float
        Nucleus-sampling probability mass.
    """
    # Flatten the conversation into a plain-text prompt.
    # NOTE(review): if the model ships a chat template, prefer
    # tokenizer.apply_chat_template — this ad-hoc "role: content" format may
    # not match the model's training format; verify against the model card.
    prompt = system_message + "\n"

    for turn in history:
        prompt += f"{turn['role']}: {turn['content']}\n"

    prompt += f"user: {message}\nassistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # skip_prompt so only newly generated text is streamed back to the UI.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        # BUG FIX: without do_sample=True, generate() defaults to greedy
        # decoding and silently ignores temperature/top_p.
        do_sample=True,
        # Avoid the "Setting pad_token_id to eos_token_id" warning (and a
        # hard error on models that define no pad token).
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer
    )

    # generate() blocks until completion, so run it on a worker thread and
    # consume the streamer from this generator.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_output = ""
    for new_token in streamer:
        partial_output += new_token
        yield partial_output

    # Make sure the generation thread has fully finished before returning.
    thread.join()


# Wire the streaming generator into a chat UI. type="messages" makes Gradio
# pass history as a list of {"role", "content"} dicts, which respond() expects.
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    # Extra controls surfaced in the UI; passed positionally to respond()
    # after (message, history), in this order.
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System message"),
        gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)

# Launch only when run as a script (not when imported, e.g. by HF Spaces tooling).
if __name__ == "__main__":
    chatbot.launch()