File size: 2,191 Bytes
aed97e1
b83ade3
 
aed97e1
b83ade3
 
aed97e1
b83ade3
a9c3fad
b83ade3
 
 
 
 
 
860282d
55ea8b6
aed97e1
 
 
 
 
 
 
b83ade3
aed97e1
b83ade3
aed97e1
 
 
 
860282d
 
b83ade3
860282d
aed97e1
b83ade3
 
 
 
 
 
 
aed97e1
 
b83ade3
 
860282d
aed97e1
b83ade3
 
 
 
 
 
aed97e1
 
860282d
aed97e1
 
 
 
860282d
b83ade3
 
aed97e1
 
 
 
 
860282d
aed97e1
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login

# Hugging Face model repo ID (must contain HF model weights, NOT .gguf)
# NOTE(review): repo must be public or HF credentials must already be
# configured in the environment — `login` is imported above but never called.
MODEL_ID = "Selinaliu1030/lora_model"

# Load tokenizer + model once at import time; both are module-level globals
# shared by every `respond` call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",            # uses GPU if available
    torch_dtype="auto",           # automatically picks fp16/bf16
    low_cpu_mem_usage=True,       # stream weights in to reduce peak RAM
)

def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token=None,   # kept for UI-signature compatibility; unused.
                     # Defaulted because additional_inputs never supplies it —
                     # without a default, ChatInterface's call raises TypeError.
):
    """Generate one assistant reply for the Gradio ChatInterface.

    Args:
        message: Latest user message (str).
        history: Prior turns in Gradio "messages" format — dicts carrying
            at least "role" and "content" keys.
        system_message: System prompt prepended to the conversation.
        max_tokens: Upper bound on newly generated tokens. Gradio sliders
            may deliver this as a float, so it is cast to int below.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling cutoff forwarded to ``model.generate``.
        hf_token: Unused; present only so the signature matches the UI.

    Yields:
        str: The assistant's decoded reply.
    """
    # Build the conversation; copy only role/content so any extra keys
    # Gradio might attach to history entries are ignored.
    messages = [{"role": "system", "content": system_message}]
    messages.extend(
        {"role": m["role"], "content": m["content"]} for m in history
    )
    messages.append({"role": "user", "content": message})

    # Simple role-tagged prompt format (no chat template is applied).
    prompt = "".join(f"<{m['role']}>: {m['content']}\n" for m in messages)
    prompt += "<assistant>: "

    # Tokenize and move tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Generate. pad_token_id is pinned to EOS to silence the warning for
    # models that define no pad token.
    output = model.generate(
        **inputs,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode ONLY the newly generated tokens. Decoding the whole sequence
    # and splitting on "<assistant>:" (the previous approach) corrupted the
    # reply whenever any message itself contained that marker.
    assistant_reply = tokenizer.decode(
        output[0][prompt_len:], skip_special_tokens=True
    ).strip()

    yield assistant_reply


# Gradio UI
# Gradio UI: a messages-format chat interface wired to `respond`, with the
# generation controls exposed as extra inputs beneath the chat box.
_generation_controls = [
    gr.Textbox(value="You are a helpful assistant.", label="System message"),
    gr.Slider(1, 2048, value=256, step=1, label="Max new tokens"),
    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
]

chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=_generation_controls,
)

# Page layout: Hugging Face login button in a sidebar, chat UI in the main
# area. `Blocks.__enter__` returns the instance, so building the object
# first and entering it afterwards is equivalent to `with ... as demo:`.
demo = gr.Blocks()
with demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    demo.launch()