Marcus719 committed on
Commit
078bd3c
·
verified ·
1 Parent(s): 4345ede

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -203
app.py CHANGED
@@ -1,226 +1,166 @@
1
- import os
2
- import traceback
3
- import time
4
- from huggingface_hub import snapshot_download
5
  import gradio as gr
 
 
 
 
6
 
7
# Try to import llama_cpp; if it fails, keep a sentinel so the UI can report it.
try:
    from llama_cpp import Llama
except Exception as e:
    Llama = None  # checked before every model load
    Llama_import_error = e  # surfaced in the error message shown to the user

# ---------- Configuration ----------
# ★★★ Change this to your own model repository ★★★
MODEL_REPO = "Marcus719/Llama-3.2-3B-Instruct-FineTome-Lab2-GGUF"
# Download only the q4_k_m file, so extra repo files don't fill the disk
GGUF_FILENAME = "unsloth.Q4_K_M.gguf"

DEFAULT_N_CTX = 2048  # context window length
DEFAULT_MAX_TOKENS = 256  # default generation length
DEFAULT_N_THREADS = 2  # 2 is recommended on the free CPU tier
# ------------------------------
24
-
25
def log(msg: str):
    """Print a timestamped, immediately-flushed app log line to stdout."""
    stamp = time.strftime('%Y-%m-%d %H:%M:%S')
    print(f"[app] {stamp} - {msg}", flush=True)
27
-
28
def load_model_from_hub(repo_id: str, filename: str, n_ctx=DEFAULT_N_CTX, n_threads=DEFAULT_N_THREADS):
    """Download a single GGUF file from the Hub and load it with llama.cpp.

    Args:
        repo_id: Hub repository id holding the GGUF weights.
        filename: exact GGUF filename to fetch (used as an allow-pattern).
        n_ctx: context window length for the llama.cpp session.
        n_threads: CPU threads for inference.

    Returns:
        (llm, gguf_path) — the loaded Llama instance and the local file path.

    Raises:
        RuntimeError: llama-cpp-python is not installed / failed to import.
        FileNotFoundError: the requested file is not in the downloaded snapshot.
    """
    if Llama is None:
        raise RuntimeError(f"llama-cpp-python 未安装或加载失败: {Llama_import_error}")

    # FIX: both messages below had lost their {filename} placeholder.
    log(f"开始下载模型: {repo_id} / {filename} ...")

    # snapshot_download with allow_patterns fetches only the GGUF file.
    local_dir = snapshot_download(
        repo_id=repo_id,
        allow_patterns=[filename],
        local_dir_use_symlinks=False  # symlinks are sometimes flaky in Spaces; disabling is safer
    )

    # The snapshot keeps the repo's directory layout; try the direct path first.
    gguf_path = os.path.join(local_dir, filename)

    # Fallback: search the snapshot recursively in case the file is nested.
    if not os.path.exists(gguf_path):
        for root, dirs, files in os.walk(local_dir):
            if filename in files:
                gguf_path = os.path.join(root, filename)
                break

    if not os.path.exists(gguf_path):
        raise FileNotFoundError(f"在 {local_dir} 中找不到 {filename}")

    log(f"模型路径: {gguf_path}。正在加载到内存...")

    # Initialize the llama.cpp model (verbose off to keep Space logs clean).
    llm = Llama(model_path=gguf_path, n_ctx=n_ctx, n_threads=n_threads, verbose=False)
    log("Llama 模型加载成功!")
    return llm, gguf_path
62
 
63
def init_model(state):
    """Callback for the init button: download and load the model into state.

    Returns (status_message, state) matching the Gradio outputs. Idempotent:
    a second click returns immediately when the model is already loaded.
    """
    try:
        if state.get("llm") is not None:
            return "✅ 系统就绪 (模型已加载)", state

        log("收到加载请求...")
        # Download and load the GGUF model from the Hub.
        llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)

        # Cache the model and its path in the session state dict.
        state["llm"] = llm
        state["gguf_path"] = gguf_path

        return "✅ 系统就绪", state
    except Exception as exc:
        tb = traceback.format_exc()
        log(f"初始化错误: {exc}\n{tb}")
        return f"❌ 初始化失败: {exc}", state
82
-
83
def generate_response(prompt: str, max_tokens: int, state):
    """Callback for the generate button.

    Returns (response_text, status_message, state) matching the Gradio
    outputs. Lazily loads the model if the user skipped the init button.
    """
    try:
        if not prompt or prompt.strip() == "":
            return "⚠️ 请输入指令。", "⚠️ 空闲", state

        # Lazy load: if init was never clicked, try to load the model now.
        if state.get("llm") is None:
            try:
                log("未检测到模型,尝试自动加载...")
                llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
                state["llm"] = llm
                state["gguf_path"] = gguf_path
            except Exception as e:
                return f"❌ 模型加载失败: {e}", f"❌ 错误", state

        llm = state.get("llm")

        log(f"正在生成 (Prompt 长度={len(prompt)})...")

        # Build a Llama 3 chat-format prompt: system + user + assistant header.
        system_prompt = "You are a helpful AI assistant."
        # Simple text concatenation instead of tokenizer.apply_chat_template;
        # Llama 3 generally understands this format as well.
        full_prompt = f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

        # Inference; stop on the end-of-turn token so output ends cleanly.
        output = llm(
            full_prompt,
            max_tokens=max_tokens,
            stop=["<|eot_id|>"],
            echo=False
        )

        text = output['choices'][0]['text']
        log("生成完成。")
        return text, "✅ 生成完毕", state

    except Exception as exc:
        tb = traceback.format_exc()
        log(f"生成错误: {exc}\n{tb}")
        return f"运行出错: {exc}", f"❌ 异常", state
126
-
127
def soft_clear(current_state):
    """Clear-button callback: blank the prompt box but keep the loaded model."""
    if current_state.get("llm"):
        status = "✅ 系统就绪"
    else:
        status = "⚪ 未初始化"
    return "", status, current_state
131
-
132
# ---------------- Gradio UI construction ----------------

# Theme settings
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate"
)

# Custom CSS for the footer note
custom_css = """
.footer-text { font-size: 0.8em; color: gray; text-align: center; }
"""

# FIX: `theme` and `custom_css` were defined but never passed to gr.Blocks,
# so the Soft theme and the .footer-text styling were silently unused.
with gr.Blocks(title="Llama 3.2 Lab2 Project", theme=theme, css=custom_css) as demo:

    # Header
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# 🦙 Llama 3.2 (3B) Fine-Tuned Chatbot")
            gr.Markdown(
                f"""
                **ID2223 Lab 2 Project** | Fine-tuned on **FineTome-100k**.
                Running on CPU (GGUF 4-bit) | Model: `{MODEL_REPO}`
                """
            )
        with gr.Column(scale=0, min_width=150):
            status_label = gr.Label(value="⚪ 未初始化", label="系统状态", show_label=False)

    # Main layout
    with gr.Row():
        # Left: input and controls
        with gr.Column(scale=4):
            with gr.Group():
                prompt_in = gr.Textbox(
                    lines=5,
                    label="用户指令 (User Input)",
                    placeholder="例如:请解释量子力学...",
                    elem_id="prompt-input"
                )

                with gr.Accordion("⚙️ 高级参数 (Advanced)", open=False):
                    max_tokens = gr.Slider(
                        minimum=16,
                        maximum=1024,
                        step=16,
                        value=DEFAULT_MAX_TOKENS,
                        label="最大生成长度 (Max Tokens)",
                        info="生成的越长,CPU 耗时越久。"
                    )

            with gr.Row():
                init_btn = gr.Button("🚀 1. 加载模型 (Load)", variant="secondary")
                gen_btn = gr.Button("✨ 2. 生成回复 (Generate)", variant="primary")

            clear_btn = gr.Button("🗑️ 清空历史 (Clear)", variant="stop")

        # Right: output display
        with gr.Column(scale=6):
            output_txt = gr.Textbox(
                label="模型回复 (Response)",
                lines=15,
            )

    # Footer note
    with gr.Row():
        gr.Markdown(
            "⚠️ *注意:推理在免费 CPU 上运行,速度可能较慢。首次运行时需要下载模型(约2GB),请耐心等待。*",
            elem_classes=["footer-text"]
        )

    # Per-session state shared between the callbacks
    state = gr.State({"llm": None, "gguf_path": None, "status": "Not initialized"})

    # Event bindings
    init_btn.click(
        fn=init_model,
        inputs=state,
        outputs=[status_label, state],
        show_progress=True
    )

    gen_btn.click(
        fn=generate_response,
        inputs=[prompt_in, max_tokens, state],
        outputs=[output_txt, status_label, state],
        show_progress=True
    )

    # Two handlers on the same button: reset prompt/status, then blank the output.
    clear_btn.click(fn=soft_clear, inputs=[state], outputs=[prompt_in, status_label, state])
    clear_btn.click(lambda: "", outputs=[output_txt])

# Launch the app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
1
  import gradio as gr
2
+ import time
3
+ import os
4
+ # from llama_cpp import Llama # Uncomment if running locally with the library installed
5
+ import numpy as np
6
 
7
# --- CONFIGURATION ---
# Local GGUF weight files (q4_k_m 4-bit quantization) for the two model sizes.
GGUF_MODEL_PATH_1B = "./llama-3.2-1b-summary-q4_k_m.gguf"
GGUF_MODEL_PATH_3B = "./llama-3.2-3b-summary-q4_k_m.gguf"

# System prompt prepended to every request (Llama-3 chat format).
SYSTEM_PROMPT = (
    "You are an expert summarization bot. Your task is to provide a comprehensive "
    "and concise summary of the user's document based on the requested length."
)
 
 
15
 
16
# ----------------------------------------------------
# 1. MODEL LOADING FUNCTION
# ----------------------------------------------------
# Requires llama-cpp-python at runtime; a missing package or model file
# degrades gracefully to a None model instead of crashing the app.
def load_llm(model_path):
    """Load a GGUF model with llama.cpp on CPU; return None on any failure."""
    print(f"Attempting to load GGUF model: {model_path}...")
    try:
        from llama_cpp import Llama
        loaded = Llama(
            model_path=model_path,
            n_gpu_layers=0,  # CPU-only inference
            n_ctx=2048,
            verbose=True
        )
    except Exception as e:
        print(f"Error loading model {model_path}: {e}")
        # Placeholder for when models are missing (prevents crash during UI testing)
        return None
    print(f"Successfully loaded model: {model_path}")
    return loaded
37
+
38
+ # Load models globally
39
+ llm_1b = load_llm(GGUF_MODEL_PATH_1B)
40
+ llm_3b = load_llm(GGUF_MODEL_PATH_3B)
41
+
42
# ----------------------------------------------------
# 2. CORE PROCESSING FUNCTION
# ----------------------------------------------------
def generate_summary_and_compare(long_document, selected_model, summary_length):
    """Summarize *long_document* with the selected GGUF model and time it.

    Args:
        long_document: raw text to summarize.
        selected_model: radio label; must mention "1B" or "3B".
        summary_length: radio label; "Detailed" selects a larger token budget.

    Returns:
        (summary_or_error, latency_report) — two plain strings for the UI.
    """
    # Guard: nothing to summarize — avoids wasting a slow CPU inference call.
    if not long_document or not long_document.strip():
        return "Error: Please provide a document to summarize.", "Latency: N/A"

    # 1. Select Model by the size tag embedded in the radio label.
    if "1B" in selected_model:
        selected_llm = llm_1b
        model_name_display = "Llama-3.2-1B"
    elif "3B" in selected_model:
        selected_llm = llm_3b
        model_name_display = "Llama-3.2-3B"
    else:
        # FIX: this path previously returned "" for the metrics box while
        # every other error path reported "Latency: N/A".
        return "Error: Invalid model selection.", "Latency: N/A"

    # Check if model loaded successfully at startup.
    if selected_llm is None:
        return "Error: Model file not found or failed to load.", "Latency: N/A"

    # 2. Build a Llama-3 chat-format prompt.
    instruction = f"Please summarize the following document and keep the summary {summary_length}. Document: \n\n{long_document}"
    full_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

    # 3. Inference, timed end-to-end.
    start_time = time.time()
    max_tokens = 250 if "Detailed" in summary_length else 100

    try:
        output = selected_llm(
            full_prompt,
            max_tokens=max_tokens,
            stop=["<|eot_id|>"],
            temperature=0.7,
            echo=False,
        )
        total_latency = time.time() - start_time
        summary_output = output["choices"][0]["text"].strip()
    except Exception as e:
        # Still report latency so the user sees how long the failed attempt took.
        total_latency = time.time() - start_time
        summary_output = f"Inference Error on {model_name_display}. Error: {e}"

    # 4. Report
    speed_report = f"Model: {model_name_display}\nTotal Latency: {total_latency:.2f} seconds"

    return summary_output, speed_report
 
 
 
 
 
87
 
88
# ----------------------------------------------------
# 3. GRADIO INTERFACE (UI IMPROVED)
# ----------------------------------------------------
# Soft theme with a clean blue/slate palette.
theme = gr.themes.Soft(
    primary_hue="blue",
    neutral_hue="slate",
).set(
    # Darken the primary button slightly on hover.
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
)

with gr.Blocks(title="KTH ID2223 Lab 2", theme=theme) as demo:

    # Header Section
    with gr.Row():
        gr.Markdown(
            """
            # LLM Document Summarizer
            Select a model and input your text below to generate a summary.
            """
        )

    with gr.Row(equal_height=False):

        # --- Left Column: Input & Controls ---
        with gr.Column(scale=4, variant="panel"):
            gr.Markdown("### Input Configuration")

            input_document = gr.Textbox(
                lines=12,
                label="Document Content",
                placeholder="Paste the text you need summarized here...",
                show_copy_button=True
            )

            # Grouping settings for a cleaner look
            with gr.Group():
                with gr.Row():
                    model_selector = gr.Radio(
                        ["Llama-3.2-1B (Faster)", "Llama-3.2-3B (Quality)"],
                        label="Model Selection",
                        value="Llama-3.2-1B (Faster)"
                    )

                summary_control = gr.Radio(
                    ["Concise (<50 words)", "Detailed (<200 words)"],
                    label="Summary Length",
                    value="Concise (<50 words)"
                )

            process_button = gr.Button("Generate Summary", variant="primary", size="lg")

        # --- Right Column: Output & Stats ---
        with gr.Column(scale=5):
            gr.Markdown("### Results")

            output_summary = gr.Textbox(
                label="Generated Summary",
                lines=10,
                interactive=False,
                show_copy_button=True
            )

            performance_report = gr.Textbox(
                label="Performance Metrics",
                lines=2,
                interactive=False
            )

    # Event Binding: one click handler drives both output boxes.
    process_button.click(
        fn=generate_summary_and_compare,
        inputs=[input_document, model_selector, summary_control],
        outputs=[output_summary, performance_report]
    )

if __name__ == "__main__":
    demo.launch()