Marcus719 committed on
Commit
d9a1250
·
verified ·
1 Parent(s): 2a8403d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +254 -24
app.py CHANGED
@@ -1,30 +1,260 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
- model_name = "Marcus719/Llama-3.2-3B-Instruct-Lab2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # load tokenizer and model
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(
9
- model_name,
10
- low_cpu_mem_usage=True,
11
- device_map="auto"
12
- )
13
 
14
- # define generate function
15
- def generate_text(input_text):
16
- inputs = tokenizer(input_text, return_tensors="pt")
17
- outputs = model.generate(inputs["input_ids"], max_length=100, num_return_sequences=1)
18
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
19
-
20
- # create gradio interface
21
- interface = gr.Interface(
22
- fn=generate_text,
23
- inputs="text",
24
- outputs="text",
25
- title="Hugging Face model Demo",
26
- description="say something"
27
  )
28
 
29
- # launch the app
30
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
 
3
+ from huggingface_hub import hf_hub_download
4
+ from llama_cpp import Llama
5
+ import os
6
+
7
+ # ============================================
8
+ # 配置区域 - KTH ID2223 Lab 2
9
+ # ============================================
10
+ MODEL_REPO = "Marcus719/Llama-3.2-3B-Instruct-FineTome-Lab2-GGUF"
11
+ # ⚠️ 请确认你仓库中的 GGUF 文件名,常见格式:
12
+ # - unsloth.Q4_K_M.gguf (推荐,较小较快)
13
+ # - unsloth.Q8_0.gguf (更精确但较慢)
14
+ MODEL_FILENAME = "unsloth.Q4_K_M.gguf"
15
+
16
+ # ============================================
17
+ # 下载并加载模型
18
+ # ============================================
19
+ print(f"📥 Downloading model from {MODEL_REPO}...")
20
+
21
+ try:
22
+ model_path = hf_hub_download(
23
+ repo_id=MODEL_REPO,
24
+ filename=MODEL_FILENAME,
25
+ )
26
+ print(f"✅ Model downloaded: {model_path}")
27
+ except Exception as e:
28
+ print(f"❌ Error downloading model: {e}")
29
+ print("Please check MODEL_FILENAME matches your repository file.")
30
+ raise e
31
+
32
+ print("🔄 Loading model (this may take a few minutes on CPU)...")
33
+
34
 
 
 
 
 
 
 
 
35
 
36
+ # 加载 GGUF 模型 - 针对 HuggingFace Spaces 免费 CPU 优化
37
+ llm = Llama(
38
+ model_path=model_path,
39
+ n_ctx=2048, # 上下文长度 (降低以节省内存)
40
+ n_threads=2, # HF Spaces 免费 CPU 线程数
41
+ n_gpu_layers=0, # 纯 CPU 推理
42
+ verbose=False
 
 
 
 
 
 
43
  )
44
 
45
+ print("✅ Model loaded successfully!")
46
+
47
# ============================================
# Llama 3.2 Instruct chat template
# ============================================
def format_prompt(message: str, history: list, system_prompt: str) -> str:
    """Build a Llama 3.2 Instruct prompt string from a chat transcript.

    See https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_2
    for the special-token layout.
    """
    segments = [
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>"
    ]

    # Replay prior turns, skipping any empty halves of a (user, assistant) pair.
    for past_user, past_assistant in history:
        if past_user:
            segments.append(
                f"<|start_header_id|>user<|end_header_id|>\n\n{past_user}<|eot_id|>"
            )
        if past_assistant:
            segments.append(
                f"<|start_header_id|>assistant<|end_header_id|>\n\n{past_assistant}<|eot_id|>"
            )

    # Current user message, then an open assistant header for the model to complete.
    segments.append(f"<|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|>")
    segments.append("<|start_header_id|>assistant<|end_header_id|>\n\n")

    return "".join(segments)
69
+
70
# ============================================
# Response generation (streaming output)
# ============================================
def chat(message: str, history: list, system_prompt: str, max_tokens: int, temperature: float, top_p: float):
    """Stream a reply from the fine-tuned LLM, yielding the growing text."""

    full_prompt = format_prompt(message, history, system_prompt)

    partial = ""
    # llama-cpp streaming call: yields chunks until a stop token or max_tokens.
    token_stream = llm(
        full_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=["<|eot_id|>", "<|end_of_text|>"],
        stream=True
    )

    for piece in token_stream:
        partial += piece["choices"][0]["text"]
        yield partial
92
+
93
# ============================================
# Gradio UI
# ============================================
# Default system prompt; the user can edit it under "Advanced Settings".
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

with gr.Blocks(
    theme=gr.themes.Soft(),
    title="🦙 Llama 3.2 Fine-tuned ChatBot | KTH ID2223 Lab 2"
) as demo:

    gr.Markdown(
        """
        # 🦙 Llama 3.2 3B Instruct - Fine-tuned on FineTome Dataset

        **KTH ID2223 Scalable Machine Learning - Lab 2**

        This chatbot uses a Llama 3.2 3B model fine-tuned on the [FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k) instruction dataset using QLoRA (4-bit quantization with LoRA adapters).

        📦 **Model**: [Marcus719/Llama-3.2-3B-Instruct-FineTome-Lab2-GGUF](https://huggingface.co/Marcus719/Llama-3.2-3B-Instruct-FineTome-Lab2-GGUF)
        """
    )

    # Conversation display; history is a list of [user, assistant] pairs.
    chatbot = gr.Chatbot(
        label="Conversation",
        height=450,
        show_copy_button=True,
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            scale=4,
            container=False,
            autofocus=True
        )
        submit_btn = gr.Button("Send 🚀", scale=1, variant="primary")

    # Sampling controls, collapsed by default.
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        system_prompt = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=4
        )
        with gr.Row():
            max_tokens = gr.Slider(
                minimum=64,
                maximum=1024,
                value=256,
                step=32,
                label="Max Tokens"
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p"
            )

    with gr.Row():
        clear_btn = gr.Button("🗑️ Clear Chat")
        retry_btn = gr.Button("🔄 Regenerate")

    # Example prompts
    gr.Examples(
        examples=[
            "Hello! Can you introduce yourself?",
            "Explain machine learning in simple terms.",
            "What is the difference between supervised and unsupervised learning?",
            "Write a short poem about artificial intelligence.",
            "How does fine-tuning improve a language model?",
        ],
        inputs=msg,
        label="💡 Example Prompts"
    )

    # Event handlers
    def user_input(message, history):
        """Clear the textbox and append the user turn with a pending (None) reply slot."""
        return "", history + [[message, None]]

    def bot_response(history, system_prompt, max_tokens, temperature, top_p):
        """Stream the model's reply into the last history entry (generator)."""
        if not history:
            # In a generator this `return` just ends iteration; the value is unused.
            return history
        message = history[-1][0]
        # Exclude the pending turn itself from the context sent to the model.
        history_for_model = history[:-1]

        for response in chat(message, history_for_model, system_prompt, max_tokens, temperature, top_p):
            history[-1][1] = response
            yield history

    def retry_last(history, system_prompt, max_tokens, temperature, top_p):
        """Re-generate the assistant reply for the most recent user message."""
        if history and len(history) > 0:
            # Drop the old reply, then stream a fresh one for the same message.
            history[-1][1] = None
            message = history[-1][0]
            history_for_model = history[:-1]

            for response in chat(message, history_for_model, system_prompt, max_tokens, temperature, top_p):
                history[-1][1] = response
                yield history

    # Submit message (Enter key): add the user turn un-queued, then stream the reply.
    msg.submit(
        user_input,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_response,
        [chatbot, system_prompt, max_tokens, temperature, top_p],
        chatbot
    )

    # Same flow for the Send button.
    submit_btn.click(
        user_input,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_response,
        [chatbot, system_prompt, max_tokens, temperature, top_p],
        chatbot
    )

    # Clear chat
    clear_btn.click(lambda: [], None, chatbot, queue=False)

    # Retry last response
    retry_btn.click(
        retry_last,
        [chatbot, system_prompt, max_tokens, temperature, top_p],
        chatbot
    )

    gr.Markdown(
        """
        ---
        ### 📝 About This Project

        **Fine-tuning Details:**
        - Base Model: `meta-llama/Llama-3.2-3B-Instruct`
        - Dataset: [FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k)
        - Method: QLoRA (4-bit quantization + LoRA)
        - Framework: [Unsloth](https://github.com/unslothai/unsloth)

        **Tips:**
        - Lower temperature (0.1-0.5) for more focused responses
        - Higher temperature (0.7-1.0) for creative responses
        - Adjust max tokens based on expected response length

        Built with ❤️ using Gradio & llama.cpp | KTH ID2223 Lab 2
        """
    )

# ============================================
# Launch
# ============================================
if __name__ == "__main__":
    # queue() enables generator (streaming) event handlers on Spaces.
    demo.queue().launch()