File size: 2,325 Bytes
4aadb97
8441218
4eb81a9
8441218
 
 
4eb81a9
75fbb74
8441218
 
54d1587
 
8441218
4aadb97
 
 
 
 
 
 
 
 
75fbb74
4aadb97
 
54d1587
8441218
4aadb97
75fbb74
8441218
 
 
 
 
 
75fbb74
 
 
 
 
 
4aadb97
54d1587
4aadb97
8441218
 
75fbb74
8441218
 
 
75fbb74
 
 
 
8441218
4aadb97
54d1587
 
4aadb97
8441218
54d1587
8441218
 
4aadb97
 
 
 
 
 
 
75fbb74
4aadb97
 
 
 
 
 
 
 
 
 
 
 
 
54d1587
4aadb97
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM

# Hugging Face model repo to serve; small (~60M-param) Vietnamese chat model.
MODEL_ID = "vietrix/viena-60m"

# legacy=True selects the older SentencePiece tokenization behaviour —
# presumably what this checkpoint was trained with (TODO: confirm).
tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, legacy=True)

# Load on CPU in full float32: the model is small enough that no GPU or
# reduced precision is needed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    device_map="cpu",
)


def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken,  # injected by Gradio's login flow; unused, kept for the OAuth wiring
):
    """Generate a chat reply with the local model, streamed character by character.

    Args:
        message: The latest user message.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts.
        system_message: System prompt prepended to the conversation.
        max_tokens: Maximum number of newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.
        hf_token: OAuth token from ``gr.LoginButton`` (not used by generation).

    Yields:
        The accumulated response text after each additional character.
    """
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        # Very simple fallback when the tokenizer ships no chat template.
        parts = [f"{m['role'].upper()}: {m['content']}" for m in messages]
        parts.append("ASSISTANT:")
        prompt = "\n".join(parts)

    inputs = tokenizer(prompt, return_tensors="pt")

    # BUG FIX: `tokenizer.pad_token_id or tokenizer.eos_token_id` would
    # silently replace a legitimate pad id of 0 (falsy) with the EOS id.
    # Only fall back when the pad id is actually missing.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    # inference_mode avoids building autograd state during generation.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            repetition_penalty=1.15,
            no_repeat_ngram_size=4,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_id,
        )

    # Drop the prompt tokens; decode only what the model generated.
    gen_ids = outputs[0, inputs.input_ids.shape[1]:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True)

    # Simulated streaming: yield the growing response one character at a time.
    resp = ""
    for ch in text:
        resp += ch
        yield resp


# Chat UI wired to `respond`; the sliders/textbox below appear as extra
# inputs and are passed to `respond` after the message and history.
chatbot = gr.ChatInterface(
    respond,
    type="messages",  # history arrives as OpenAI-style role/content dicts
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

# Wrap the chat UI in a Blocks layout so a login button can live in a sidebar.
# NOTE(review): the OAuth token from LoginButton is presumably what Gradio
# injects into `respond`'s `hf_token` parameter — confirm against Gradio docs.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    demo.launch()