File size: 3,503 Bytes
f4455e9
059c4aa
f4455e9
031ecb9
f4455e9
059c4aa
 
 
 
f4455e9
 
 
059c4aa
 
f4455e9
 
 
031ecb9
f4455e9
 
 
031ecb9
 
 
 
f4455e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
059c4aa
 
 
 
 
 
 
d0a99a2
f4455e9
059c4aa
 
 
 
 
 
 
 
 
 
 
f4455e9
059c4aa
f4455e9
d0a99a2
031ecb9
f4455e9
 
 
 
 
 
059c4aa
031ecb9
 
059c4aa
 
 
 
 
 
8b3dbaa
d0a99a2
e65a3d7
059c4aa
 
 
8b3dbaa
059c4aa
 
 
d0a99a2
 
 
 
 
 
 
f4455e9
 
 
e65a3d7
f4455e9
 
 
059c4aa
 
 
 
 
8b3dbaa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import queue
import gradio as gr
import torch
import threading
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model setup: the tokenizer and weights for this Hub repo are fetched (or read from
# the local Hugging Face cache) and the model runs locally through `transformers`.
# No remote Inference API is involved. Weights are loaded in bfloat16.
checkpoint = "LemiSt/SmolLM-135M-instruct-de-merged"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)


class CustomIterable:
    """Thread-safe hand-off between ``model.generate`` (producer) and a consumer loop.

    An instance is passed to ``generate`` as its ``streamer``. ``generate`` calls
    :meth:`put` with token tensors and :meth:`end` once it is done. The consumer
    iterates the instance to receive those items in order, blocking until each
    one arrives.

    Args:
        skip_prompt: If True (the default), the first item given to :meth:`put`
            is dropped. ``generate`` hands the prompt ids to the streamer before
            any newly generated token, so this keeps the prompt out of the stream.
        timeout: Seconds to wait for each item before iteration stops. ``None``
            (the default) waits indefinitely.
    """

    def __init__(self, skip_prompt=True, timeout=None):
        self._queue = queue.Queue()  # Thread-safe FIFO shared by producer and consumer
        # True while the next put() should be discarded (only the very first one).
        self.first = skip_prompt
        self._timeout = timeout
        # Set once iteration has stopped so later next() calls stop immediately
        # instead of blocking on an empty queue.
        self._finished = False

    def put(self, item):
        """Queue ``item`` for the consumer, dropping the first item if requested."""
        if self.first:
            self.first = False
            return
        self._queue.put(item)

    def end(self):
        """Signal that no more items will be added."""
        self._queue.put(None)  # Sentinel value: the consumer stops when it reads None

    def __iter__(self):
        """Return the iterator (this object itself)."""
        return self

    def __next__(self):
        """Return the next queued item, waiting up to ``timeout`` seconds for it.

        Raises:
            StopIteration: once :meth:`end` was called and every earlier item was
                consumed, or when ``timeout`` elapses without an item. After it
                has been raised once, every later call raises it again at once.
        """
        if self._finished:
            raise StopIteration
        try:
            item = self._queue.get(block=True, timeout=self._timeout)
        except queue.Empty:
            # Only reachable when a timeout was configured.
            self._finished = True
            raise StopIteration from None

        if item is None:  # Sentinel value from end()
            self._finished = True
            raise StopIteration

        return item

def _build_messages(message, history, system_message):
    """Return the chat-template message list: system prompt, prior turns, new message.

    ``history`` holds (user, assistant) pairs; a falsy half of a pair is skipped.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": message})
    return messages


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k,
    repetition_penalty
):
    """Stream the model's reply to ``message`` for the Gradio ChatInterface.

    Generation runs on a worker thread. Each value yielded is the complete reply
    text generated so far.

    Args:
        message: The user's new message.
        history: Previous (user, assistant) turns as supplied by the ChatInterface.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to already-generated tokens.

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread. It is
            re-raised here after streaming stops, so the failure is surfaced
            instead of leaving the consumer waiting forever.
    """
    messages = _build_messages(message, history, system_message)

    streamer = CustomIterable()

    inputs = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=True)
    generate_kwargs = {
        # Slider values may arrive as floats; generate() needs ints for these two.
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }
    errors = []

    def _generate():
        # Worker-thread body. Always close the stream, even if generate() fails;
        # otherwise the consumer loop below would block forever on the queue.
        try:
            model.generate(inputs, **generate_kwargs)
        except Exception as exc:  # handed back to the caller's thread below
            errors.append(exc)
        finally:
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    generated_ids = []
    response = ""
    held_back = False
    for token in streamer:
        generated_ids.extend(torch.as_tensor(token).flatten().tolist())
        # Re-decode the whole reply each step. Decoding tokens one at a time can
        # split a multi-byte character across tokens and produce U+FFFD.
        response = tokenizer.decode(generated_ids, skip_special_tokens=True)
        held_back = response.endswith("\ufffd")
        if held_back:
            continue  # wait for the remaining bytes of an incomplete character
        yield response
    if held_back:
        # Generation ended mid-character; still show the final text.
        yield response

    thread.join()
    if errors:
        raise errors[0]

# Customization options for the chat UI are described in the Gradio docs:
# https://www.gradio.app/docs/chatinterface


def _settings_inputs():
    """Build the extra controls shown with the chat box.

    Their order matches the parameters of ``respond`` that follow ``history``.
    """
    system_box = gr.Textbox(value="Du bist ein hilfreicher Assistent.", label="System message")
    max_tokens_slider = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max new tokens")
    temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.4, step=0.1, label="Temperature")
    top_p_slider = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"
    )
    top_k_slider = gr.Slider(minimum=16, maximum=1024, value=512, step=1, label="Top-k")
    penalty_slider = gr.Slider(
        minimum=0.1, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"
    )
    return [
        system_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
        top_k_slider,
        penalty_slider,
    ]


demo = gr.ChatInterface(respond, additional_inputs=_settings_inputs())


if __name__ == "__main__":
    demo.launch()