File size: 5,175 Bytes
e280430
 
 
39de6aa
 
 
 
 
e280430
39de6aa
e280430
39de6aa
 
 
e280430
39de6aa
 
 
 
 
 
 
 
 
e280430
39de6aa
 
e280430
39de6aa
e280430
 
 
 
 
 
39de6aa
 
 
 
e280430
39de6aa
 
 
 
e280430
39de6aa
 
 
 
c65e984
39de6aa
 
 
c65e984
39de6aa
c65e984
39de6aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e280430
 
39de6aa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
from huggingface_hub import InferenceClient

# ============================================
# KTH ID2223 Lab 2 - Llama 3.2 ChatBot
# ============================================
# Use your fine-tuned model (safetensors format).
# Served via the Hugging Face Inference API, so no local weights are needed.
MODEL_ID = "Marcus719/Llama-3.2-3B-Instruct-Lab2"

# Module-level client shared by all requests; created once at import time.
client = InferenceClient(model=MODEL_ID)

def chat(message, history, system_message, max_tokens, temperature, top_p):
    """Stream a reply from the HF Inference API for *message*.

    Builds the full conversation (system prompt, prior turns, current
    message) and yields the accumulated response text after every
    streamed token, so the caller can render it incrementally.

    history is a sequence of (user, assistant) pairs; falsy entries
    (e.g. a pending assistant slot) are skipped.
    """
    # Reconstruct the conversation in OpenAI-style message dicts.
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    stream = client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )

    # Accumulate tokens and yield the running text after each one.
    partial = ""
    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            partial += chunk.choices[0].delta.content
            yield partial

# ============================================
# Gradio UI
# ============================================
# Default system prompt shown (and editable) in the Settings accordion.
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."

with gr.Blocks(theme=gr.themes.Soft(), title="🦙 Llama 3.2 ChatBot") as demo:

    gr.Markdown(
        """
        # 🦙 Llama 3.2 3B Instruct - Fine-tuned on FineTome

        **KTH ID2223 Scalable Machine Learning - Lab 2**

        This chatbot uses my fine-tuned Llama 3.2 3B model trained on the FineTome-100k dataset.

        📦 Model: [Marcus719/Llama-3.2-3B-Instruct-Lab2](https://huggingface.co/Marcus719/Llama-3.2-3B-Instruct-Lab2)
        """
    )

    # Chat history widget; state is a list of [user, assistant] pairs
    # (the default tuples format of gr.Chatbot).
    chatbot = gr.Chatbot(label="Chat", height=450, show_copy_button=True)

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here...",
            scale=4,
            container=False,
            autofocus=True
        )
        submit_btn = gr.Button("Send 🚀", scale=1, variant="primary")

    # Generation parameters, collapsed by default.
    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=2
        )
        with gr.Row():
            max_tokens = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
            temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

    with gr.Row():
        clear_btn = gr.Button("🗑️ Clear Chat")
        retry_btn = gr.Button("🔄 Regenerate")

    gr.Examples(
        examples=[
            "Hello! Can you introduce yourself?",
            "Explain machine learning in simple terms.",
            "What is the difference between fine-tuning and pre-training?",
            "Write a short poem about AI.",
        ],
        inputs=msg,
        label="💡 Try these examples"
    )

    # Event handlers
    def user_input(message, history):
        """Append the user's message (with a pending assistant slot) and clear the textbox."""
        return "", history + [[message, None]]

    def bot_response(history, system_prompt, max_tokens, temperature, top_p):
        """Stream the assistant reply into the last history pair.

        Mutates history[-1][1] in place and yields the whole history on
        each token so the Chatbot updates incrementally.
        """
        if not history:
            return history
        message = history[-1][0]
        # Exclude the pending pair: chat() appends the current message itself.
        history_for_model = history[:-1]
        for response in chat(message, history_for_model, system_prompt, max_tokens, temperature, top_p):
            history[-1][1] = response
            yield history

    def retry_last(history, system_prompt, max_tokens, temperature, top_p):
        """Discard the last assistant reply and regenerate it from the same user message."""
        if history:
            history[-1][1] = None
            message = history[-1][0]
            history_for_model = history[:-1]
            for response in chat(message, history_for_model, system_prompt, max_tokens, temperature, top_p):
                history[-1][1] = response
                yield history

    # Enter key and Send button share the same two-step pipeline:
    # queue=False makes the echo step instant; .then() streams the reply.
    msg.submit(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot
    )
    submit_btn.click(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot
    )
    # Clear resets the Chatbot state to an empty history.
    clear_btn.click(lambda: [], None, chatbot, queue=False)
    retry_btn.click(retry_last, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot)

    gr.Markdown(
        """
        ---
        ### 📝 About This Project

        **Fine-tuning Details:**
        - Base Model: `meta-llama/Llama-3.2-3B-Instruct`
        - Dataset: [FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k)
        - Method: QLoRA (4-bit quantization + LoRA)
        - Framework: [Unsloth](https://github.com/unslothai/unsloth)

        Built with ❤️ for KTH ID2223 Lab 2
        """
    )

if __name__ == "__main__":
    # Start the Gradio server when run as a script (e.g. on HF Spaces).
    demo.launch()