Marcus719 committed on
Commit
9ebaef7
·
verified ·
1 Parent(s): 078bd3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +203 -143
app.py CHANGED
@@ -1,166 +1,226 @@
1
- import gradio as gr
2
- import time
3
  import os
4
- # from llama_cpp import Llama # Uncomment if running locally with the library installed
5
- import numpy as np
 
 
6
 
7
- # --- CONFIGURATION ---
8
- GGUF_MODEL_PATH_1B = "./llama-3.2-1b-summary-q4_k_m.gguf"
9
- GGUF_MODEL_PATH_3B = "./llama-3.2-3b-summary-q4_k_m.gguf"
 
 
 
10
 
11
- SYSTEM_PROMPT = (
12
- "You are an expert summarization bot. Your task is to provide a comprehensive "
13
- "and concise summary of the user's document based on the requested length."
14
- )
 
15
 
16
- # ----------------------------------------------------
17
- # 1. MODEL LOADING FUNCTION
18
- # ----------------------------------------------------
19
- # Note: For demonstration purposes, I am keeping your logic structure.
20
- # Ensure llama-cpp-python is installed to run this part.
21
- def load_llm(model_path):
22
- print(f"Attempting to load GGUF model: {model_path}...")
23
- try:
24
- from llama_cpp import Llama
25
- llm = Llama(
26
- model_path=model_path,
27
- n_gpu_layers=0,
28
- n_ctx=2048,
29
- verbose=True
30
- )
31
- print(f"Successfully loaded model: {model_path}")
32
- return llm
33
- except Exception as e:
34
- print(f"Error loading model {model_path}: {e}")
35
- # Placeholder for when models are missing (prevents crash during UI testing)
36
- return None
37
-
38
- # Load models globally
39
- llm_1b = load_llm(GGUF_MODEL_PATH_1B)
40
- llm_3b = load_llm(GGUF_MODEL_PATH_3B)
41
-
42
- # ----------------------------------------------------
43
- # 2. CORE PROCESSING FUNCTION
44
- # ----------------------------------------------------
45
- def generate_summary_and_compare(long_document, selected_model, summary_length):
46
- # 1. Select Model
47
- if "1B" in selected_model:
48
- selected_llm = llm_1b
49
- model_name_display = "Llama-3.2-1B"
50
- elif "3B" in selected_model:
51
- selected_llm = llm_3b
52
- model_name_display = "Llama-3.2-3B"
53
- else:
54
- return "Error: Invalid model selection.", ""
55
-
56
- # Check if model loaded successfully
57
- if selected_llm is None:
58
- return "Error: Model file not found or failed to load.", "Latency: N/A"
59
-
60
- # 2. Build Prompt
61
- instruction = f"Please summarize the following document and keep the summary {summary_length}. Document: \n\n{long_document}"
62
- full_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
63
-
64
- # 3. Inference
65
- start_time = time.time()
66
- max_tokens = 250 if "Detailed" in summary_length else 100
67
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  try:
69
- output = selected_llm(
70
- full_prompt,
71
- max_tokens=max_tokens,
72
- stop=["<|eot_id|>"],
73
- temperature=0.7,
74
- echo=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  )
76
- end_time = time.time()
77
- total_latency = end_time - start_time
78
- summary_output = output["choices"][0]["text"].strip()
79
- except Exception as e:
80
- total_latency = time.time() - start_time
81
- summary_output = f"Inference Error on {model_name_display}. Error: {e}"
82
-
83
- # 4. Report
84
- speed_report = f"Model: {model_name_display}\nTotal Latency: {total_latency:.2f} seconds"
85
-
86
- return summary_output, speed_report
 
 
 
87
 
88
- # ----------------------------------------------------
89
- # 3. GRADIO INTERFACE (UI IMPROVED)
90
- # ----------------------------------------------------
91
- # 使用 Soft 主题,色调简洁
92
  theme = gr.themes.Soft(
93
- primary_hue="blue",
94
- neutral_hue="slate",
95
- ).set(
96
- button_primary_background_fill="*primary_500",
97
- button_primary_background_fill_hover="*primary_600",
98
  )
99
 
100
- with gr.Blocks(title="KTH ID2223 Lab 2", theme=theme) as demo:
 
 
 
 
 
101
 
102
- # Header Section
103
  with gr.Row():
104
- gr.Markdown(
105
- """
106
- # LLM Document Summarizer
107
- Select a model and input your text below to generate a summary.
108
- """
109
- )
110
-
111
- with gr.Row(equal_height=False):
112
-
113
- # --- Left Column: Input & Controls ---
114
- with gr.Column(scale=4, variant="panel"):
115
- gr.Markdown("### Input Configuration")
116
-
117
- input_document = gr.Textbox(
118
- lines=12,
119
- label="Document Content",
120
- placeholder="Paste the text you need summarized here...",
121
- show_copy_button=True
122
  )
123
-
124
- # Grouping settings for a cleaner look
 
 
 
 
 
125
  with gr.Group():
126
- with gr.Row():
127
- model_selector = gr.Radio(
128
- ["Llama-3.2-1B (Faster)", "Llama-3.2-3B (Quality)"],
129
- label="Model Selection",
130
- value="Llama-3.2-1B (Faster)"
 
 
 
 
 
 
 
 
 
 
131
  )
132
-
133
- summary_control = gr.Radio(
134
- ["Concise (<50 words)", "Detailed (<200 words)"],
135
- label="Summary Length",
136
- value="Concise (<50 words)"
137
- )
138
-
139
- process_button = gr.Button("Generate Summary", variant="primary", size="lg")
140
-
141
- # --- Right Column: Output & Stats ---
142
- with gr.Column(scale=5):
143
- gr.Markdown("### Results")
144
 
145
- output_summary = gr.Textbox(
146
- label="Generated Summary",
147
- lines=10,
148
- interactive=False,
149
- show_copy_button=True
150
- )
151
 
152
- performance_report = gr.Textbox(
153
- label="Performance Metrics",
154
- lines=2,
155
- interactive=False
 
 
 
156
  )
157
 
158
- # Event Binding
159
- process_button.click(
160
- fn=generate_summary_and_compare,
161
- inputs=[input_document, model_selector, summary_control],
162
- outputs=[output_summary, performance_report]
 
 
 
 
 
 
 
 
 
 
 
163
  )
 
 
 
 
 
 
 
 
 
 
164
 
 
165
  if __name__ == "__main__":
166
- demo.launch()
 
 
 
1
  import os
2
+ import traceback
3
+ import time
4
+ from huggingface_hub import snapshot_download
5
+ import gradio as gr
6
 
7
+ # 尝试导入 llama_cpp,如果失败则在 UI 中提示
8
+ try:
9
+ from llama_cpp import Llama
10
+ except Exception as e:
11
+ Llama = None
12
+ Llama_import_error = e
13
 
14
+ # ---------- 配置区域 ----------
15
+ # ★★★ 请在这里修改为你的模型仓库 ★★★
16
+ MODEL_REPO = "Marcus719/Llama-3.2-3B-Instruct-FineTome-Lab2-GGUF"
17
+ # 指定只下载 q4_k_m 文件,防止下载多余文件爆盘
18
+ GGUF_FILENAME = "unsloth.Q4_K_M.gguf"
19
 
20
+ DEFAULT_N_CTX = 2048 # 上下文长度
21
+ DEFAULT_MAX_TOKENS = 256 # 默认生成长度
22
+ DEFAULT_N_THREADS = 2 # 免费 CPU 建议设为 2
23
+ # ------------------------------
24
+
25
def log(msg: str):
    """Print a timestamped log line with the app tag, flushed immediately."""
    timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    print(f"[app] {timestamp} - {msg}", flush=True)
27
+
28
def load_model_from_hub(repo_id: str, filename: str, n_ctx=DEFAULT_N_CTX, n_threads=DEFAULT_N_THREADS):
    """Download a single GGUF file from the Hugging Face Hub and load it with llama.cpp.

    Args:
        repo_id: Hub repository id that contains the GGUF file.
        filename: Name of the GGUF file to download.
        n_ctx: Context window size passed to Llama.
        n_threads: Number of CPU threads used for inference.

    Returns:
        (llm, gguf_path): the loaded Llama instance and the local file path.

    Raises:
        RuntimeError: if llama-cpp-python could not be imported.
        FileNotFoundError: if the requested file is missing from the snapshot.
    """
    if Llama is None:
        raise RuntimeError(f"llama-cpp-python 未安装或加载失败: {Llama_import_error}")

    # FIX: these messages previously interpolated a garbled "(unknown)"
    # placeholder; restore the intended {filename} interpolation.
    log(f"开始下载模型: {repo_id} / {filename} ...")

    # Use snapshot_download to fetch a single file; allow_patterns ensures
    # only the GGUF is downloaded.
    local_dir = snapshot_download(
        repo_id=repo_id,
        allow_patterns=[filename],
        # Symlinks are sometimes flaky in Spaces; disabling them is safer.
        # NOTE(review): this kwarg is deprecated in recent huggingface_hub
        # releases (it is ignored there) — confirm the pinned hub version.
        local_dir_use_symlinks=False
    )

    # snapshot_download keeps the repo's directory layout, so the file is
    # usually at the top level of the snapshot directory.
    gguf_path = os.path.join(local_dir, filename)

    # If the direct join misses, search the snapshot recursively (fault tolerance).
    if not os.path.exists(gguf_path):
        for root, _dirs, files in os.walk(local_dir):
            if filename in files:
                gguf_path = os.path.join(root, filename)
                break

    if not os.path.exists(gguf_path):
        raise FileNotFoundError(f"在 {local_dir} 中找不到 {filename}")

    log(f"模型路径: {gguf_path}。正在加载到内存...")

    # Initialize the model.
    llm = Llama(model_path=gguf_path, n_ctx=n_ctx, n_threads=n_threads, verbose=False)
    log("Llama 模型加载成功!")
    return llm, gguf_path
62
+
63
def init_model(state):
    """Callback for the "Load" button.

    Downloads/loads the model into `state` (idempotent: a second click is a
    no-op) and returns a (status_message, state) pair for the Gradio outputs.
    """
    try:
        # Model already present — nothing to do.
        if state.get("llm") is not None:
            return "✅ 系统就绪 (模型已加载)", state

        log("收到加载请求...")
        # Download the GGUF file and initialize llama.cpp.
        model, path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)

        # Persist into the session state.
        state["llm"] = model
        state["gguf_path"] = path
        return "✅ 系统就绪", state
    except Exception as exc:
        details = traceback.format_exc()
        log(f"初始化错误: {exc}\n{details}")
        return f"❌ 初始化失败: {exc}", state
82
+
83
def generate_response(prompt: str, max_tokens: int, state):
    """Callback for the "Generate" button.

    Wraps `prompt` in a Llama 3 chat template, runs inference with the model
    stored in `state`, and returns (response_text, status_message, state).
    If the model has not been loaded yet, it is loaded lazily here.
    """
    try:
        if not prompt or prompt.strip() == "":
            return "⚠️ 请输入指令。", "⚠️ 空闲", state

        # Lazy load: user clicked "Generate" without clicking "Load" first.
        if state.get("llm") is None:
            try:
                log("未检测到模型,尝试自动加载...")
                llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
                state["llm"] = llm
                state["gguf_path"] = gguf_path
            except Exception as e:
                # FIX: dropped the pointless f-prefix on the constant status
                # string (same runtime value).
                return f"❌ 模型加载失败: {e}", "❌ 错误", state

        llm = state.get("llm")

        log(f"正在生成 (Prompt 长度={len(prompt)})...")

        # Build a Llama 3 style prompt: system turn + user turn, then the
        # assistant header so the model continues as the assistant.
        # (For stricter formatting one could use tokenizer.apply_chat_template;
        # plain concatenation is generally understood by Llama 3.)
        system_prompt = "You are a helpful AI assistant."
        full_prompt = (
            f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n"
        )

        # Inference
        output = llm(
            full_prompt,
            max_tokens=max_tokens,
            stop=["<|eot_id|>"],  # end-of-turn stop token
            echo=False,
        )

        text = output['choices'][0]['text']
        log("生成完成。")
        return text, "✅ 生成完毕", state

    except Exception as exc:
        tb = traceback.format_exc()
        log(f"生成错误: {exc}\n{tb}")
        # FIX: dropped the pointless f-prefix on the constant status string.
        return f"运行出错: {exc}", "❌ 异常", state
126
+
127
def soft_clear(current_state):
    """Callback for the "Clear" button: blank the prompt, keep the loaded model."""
    if current_state.get("llm"):
        status = "✅ 系统就绪"
    else:
        status = "⚪ 未初始化"
    return "", status, current_state
131
 
132
# ---------------- Gradio UI construction ----------------

# Theme configuration
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate"
)

# Custom CSS
custom_css = """
.footer-text { font-size: 0.8em; color: gray; text-align: center; }
"""

# FIX: `theme` and `custom_css` were defined above but never passed to
# gr.Blocks, so they had no effect; wire them in here.
with gr.Blocks(title="Llama 3.2 Lab2 Project", theme=theme, css=custom_css) as demo:

    # Header
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# 🦙 Llama 3.2 (3B) Fine-Tuned Chatbot")
            gr.Markdown(
                f"""
                **ID2223 Lab 2 Project** | Fine-tuned on **FineTome-100k**.
                Running on CPU (GGUF 4-bit) | Model: `{MODEL_REPO}`
                """
            )
        with gr.Column(scale=0, min_width=150):
            status_label = gr.Label(value="⚪ 未初始化", label="系统状态", show_label=False)

    # Main layout
    with gr.Row():
        # Left column: input and controls
        with gr.Column(scale=4):
            with gr.Group():
                prompt_in = gr.Textbox(
                    lines=5,
                    label="用户指令 (User Input)",
                    placeholder="例如:请解释量子力学...",
                    elem_id="prompt-input"
                )

                with gr.Accordion("⚙️ 高级参数 (Advanced)", open=False):
                    max_tokens = gr.Slider(
                        minimum=16,
                        maximum=1024,
                        step=16,
                        value=DEFAULT_MAX_TOKENS,
                        label="最大生成长度 (Max Tokens)",
                        info="生成的越长,CPU 耗时越久。"
                    )

            with gr.Row():
                init_btn = gr.Button("🚀 1. 加载模型 (Load)", variant="secondary")
                gen_btn = gr.Button("✨ 2. 生成回复 (Generate)", variant="primary")

            clear_btn = gr.Button("🗑️ 清空历史 (Clear)", variant="stop")

        # Right column: output display
        with gr.Column(scale=6):
            output_txt = gr.Textbox(
                label="模型回复 (Response)",
                lines=15,
            )

    # Footer note
    with gr.Row():
        gr.Markdown(
            "⚠️ *注意:推理在免费 CPU 上运行,速度可能较慢。首次运行时需要下载模型(约2GB),请耐心等待。*",
            elem_classes=["footer-text"]
        )

    # Per-session state holding the loaded model and its path.
    state = gr.State({"llm": None, "gguf_path": None, "status": "Not initialized"})

    # Event bindings
    init_btn.click(
        fn=init_model,
        inputs=state,
        outputs=[status_label, state],
        show_progress=True
    )

    gen_btn.click(
        fn=generate_response,
        inputs=[prompt_in, max_tokens, state],
        outputs=[output_txt, status_label, state],
        show_progress=True
    )

    # Two handlers on the same click: clear prompt/status, then clear output.
    clear_btn.click(fn=soft_clear, inputs=[state], outputs=[prompt_in, status_label, state])
    clear_btn.click(lambda: "", outputs=[output_txt])

# Launch the app (bind to all interfaces on the standard Spaces port).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)