Maoxt commited on
Commit
643dcb5
·
verified ·
1 Parent(s): 4345ede

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -87
app.py CHANGED
@@ -4,22 +4,21 @@ import time
4
  from huggingface_hub import snapshot_download
5
  import gradio as gr
6
 
7
- # 尝试导入 llama_cpp,如果失败则在 UI 中提示
8
  try:
9
  from llama_cpp import Llama
10
  except Exception as e:
11
  Llama = None
12
  Llama_import_error = e
13
 
14
- # ---------- 配置区域 ----------
15
- # ★★★ 请在这里修改为你的模型仓库 ★★★
16
- MODEL_REPO = "Marcus719/Llama-3.2-3B-Instruct-FineTome-Lab2-GGUF"
17
- # 指定只下载 q4_k_m 文件,防止下载多余文件爆盘
18
  GGUF_FILENAME = "unsloth.Q4_K_M.gguf"
19
-
20
- DEFAULT_N_CTX = 2048 # 上下文长度
21
- DEFAULT_MAX_TOKENS = 256 # 默认生成长度
22
- DEFAULT_N_THREADS = 2 # 免费 CPU 建议设为 2
23
  # ------------------------------
24
 
25
  def log(msg: str):
@@ -27,125 +26,119 @@ def log(msg: str):
27
 
28
  def load_model_from_hub(repo_id: str, filename: str, n_ctx=DEFAULT_N_CTX, n_threads=DEFAULT_N_THREADS):
29
  if Llama is None:
30
- raise RuntimeError(f"llama-cpp-python 未安装或加载失败: {Llama_import_error}")
31
 
32
- log(f"开始下载模型: {repo_id} / {filename} ...")
33
 
34
- # 使用 snapshot_download 下载单个文件
35
- # allow_patterns 确保只下载 GGUF
36
  local_dir = snapshot_download(
37
  repo_id=repo_id,
38
  allow_patterns=[filename],
39
- local_dir_use_symlinks=False # Space 中有时软链接会有问题,禁用更稳
40
  )
41
 
42
- # 拼接完整路径
43
- # snapshot_download 默认会保持目录结构,或者我们直接搜寻下载目录
44
  gguf_path = os.path.join(local_dir, filename)
45
 
46
- # 如果直接拼接找不到,尝试搜索(容错)
47
  if not os.path.exists(gguf_path):
48
  for root, dirs, files in os.walk(local_dir):
49
  if filename in files:
50
  gguf_path = os.path.join(root, filename)
51
  break
 
 
 
 
52
 
53
- if not os.path.exists(gguf_path):
54
- raise FileNotFoundError(f"在 {local_dir} 中找不到 {filename}")
55
-
56
- log(f"模型路径: {gguf_path}。正在加载到内存...")
57
-
58
- # 初始化模型
59
  llm = Llama(model_path=gguf_path, n_ctx=n_ctx, n_threads=n_threads, verbose=False)
60
- log("Llama 模型加载成功!")
61
  return llm, gguf_path
62
 
63
  def init_model(state):
64
- """初始化按钮的回调函数"""
65
  try:
66
  if state.get("llm") is not None:
67
- return "✅ 系统就绪 (模型已加载)", state
68
 
69
- log("收到加载请求...")
70
- # 下载并加载
71
  llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
72
 
73
- # 更新状态
74
  state["llm"] = llm
75
  state["gguf_path"] = gguf_path
76
 
77
- return "✅ 系统就绪", state
78
  except Exception as exc:
79
  tb = traceback.format_exc()
80
- log(f"初始化错误: {exc}\n{tb}")
81
- return f"❌ 初始化失败: {exc}", state
82
 
83
  def generate_response(prompt: str, max_tokens: int, state):
84
- """生成按钮的回调函数"""
85
  try:
86
  if not prompt or prompt.strip() == "":
87
- return "⚠️ 请输入指令。", "⚠️ 空闲", state
88
 
89
- # 懒加载:如果没点初始化直接点生成,尝试自动加载
90
  if state.get("llm") is None:
91
  try:
92
- log("未检测到模型,尝试自动加载...")
93
  llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
94
  state["llm"] = llm
95
  state["gguf_path"] = gguf_path
96
  except Exception as e:
97
- return f"❌ 模型加载失败: {e}", f"❌ 错误", state
98
-
99
  llm = state.get("llm")
100
 
101
- log(f"正在生成 (Prompt 长度={len(prompt)})...")
102
 
103
- # 构造 Llama 3 格式的 Prompt
104
  system_prompt = "You are a helpful AI assistant."
105
- # 简单拼接:System + User
106
- # 如果需要更严格的格式,可以使用 tokenizer.apply_chat_template
107
- # 这里为了通用性使用简单的文本拼接,Llama 3 通常也能理解
108
  full_prompt = f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
109
-
110
- # 推理
111
  output = llm(
112
  full_prompt,
113
  max_tokens=max_tokens,
114
- stop=["<|eot_id|>"], # 停止符
115
  echo=False
116
  )
117
 
118
  text = output['choices'][0]['text']
119
- log("生成完成。")
120
- return text, "✅ 生成完毕", state
121
-
122
  except Exception as exc:
123
  tb = traceback.format_exc()
124
- log(f"生成错误: {exc}\n{tb}")
125
- return f"运行出错: {exc}", f"❌ 异常", state
126
 
127
  def soft_clear(current_state):
128
- """清除按钮:只清空文本,保留模型"""
129
- status = "✅ 系统就绪" if current_state.get("llm") else "⚪ 未初始化"
130
  return "", status, current_state
131
 
132
- # ---------------- Gradio UI 构建 ----------------
133
-
134
- # 主题设置
135
  theme = gr.themes.Soft(
136
  primary_hue="indigo",
137
  secondary_hue="slate",
138
- neutral_hue="slate"
139
- )
140
 
141
- # 自定义 CSS
142
- custom_css = """
143
- .footer-text { font-size: 0.8em; color: gray; text-align: center; }
144
- """
145
 
146
- with gr.Blocks(title="Llama 3.2 Lab2 Project") as demo:
147
 
148
- # 标题头
149
  with gr.Row():
150
  with gr.Column(scale=1):
151
  gr.Markdown("# 🦙 Llama 3.2 (3B) Fine-Tuned Chatbot")
@@ -156,54 +149,54 @@ with gr.Blocks(title="Llama 3.2 Lab2 Project") as demo:
156
  """
157
  )
158
  with gr.Column(scale=0, min_width=150):
159
- status_label = gr.Label(value="⚪ 未初始化", label="系统状态", show_label=False)
160
 
161
- # 主体布局
162
  with gr.Row():
163
- # 左侧:输入与控制
164
  with gr.Column(scale=4):
165
  with gr.Group():
166
  prompt_in = gr.Textbox(
167
  lines=5,
168
- label="用户指令 (User Input)",
169
- placeholder="例如:请解释量子力学...",
170
  elem_id="prompt-input"
171
  )
172
-
173
- with gr.Accordion("⚙️ 高级参数 (Advanced)", open=False):
174
  max_tokens = gr.Slider(
175
  minimum=16,
176
  maximum=1024,
177
  step=16,
178
  value=DEFAULT_MAX_TOKENS,
179
- label="最大生成长度 (Max Tokens)",
180
- info="生成的越长,CPU 耗时越久。"
181
  )
182
-
183
- with gr.Row():
184
- init_btn = gr.Button("🚀 1. 加载模型 (Load)", variant="secondary")
185
- gen_btn = gr.Button("✨ 2. 生成回复 (Generate)", variant="primary")
186
-
187
- clear_btn = gr.Button("🗑️ 清空历史 (Clear)", variant="stop")
188
-
189
- # 右侧:输出显示
190
  with gr.Column(scale=6):
191
  output_txt = gr.Textbox(
192
- label="模型回复 (Response)",
193
  lines=15,
194
  )
195
 
196
- # 底部说明
197
  with gr.Row():
198
  gr.Markdown(
199
- "⚠️ *注意:推理在免费 CPU 上运行,速度可能较慢。首次运行时需要下载模型(约2GB),请耐心等待。*",
200
  elem_classes=["footer-text"]
201
  )
202
 
203
- # 状态存储
204
  state = gr.State({"llm": None, "gguf_path": None, "status": "Not initialized"})
205
 
206
- # 事件绑定
207
  init_btn.click(
208
  fn=init_model,
209
  inputs=state,
@@ -221,6 +214,6 @@ with gr.Blocks(title="Llama 3.2 Lab2 Project") as demo:
221
  clear_btn.click(fn=soft_clear, inputs=[state], outputs=[prompt_in, status_label, state])
222
  clear_btn.click(lambda: "", outputs=[output_txt])
223
 
224
- # 启动应用
225
  if __name__ == "__main__":
226
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
4
  from huggingface_hub import snapshot_download
5
  import gradio as gr
6
 
7
+ # Attempt to import llama_cpp, if failed, prompt in the UI
8
  try:
9
  from llama_cpp import Llama
10
  except Exception as e:
11
  Llama = None
12
  Llama_import_error = e
13
 
14
+ # ---------- Configuration Area ----------
15
+ # ★★★ Please change this to your model repository ★★★
16
+ MODEL_REPO = "Marcus719/Llama-3.2-3B-Instruct-FineTome-Lab2-GGUF"
17
+ # Specify to download only the q4_k_m file to prevent running out of disk space
18
  GGUF_FILENAME = "unsloth.Q4_K_M.gguf"
19
+ DEFAULT_N_CTX = 2048 # Context length
20
+ DEFAULT_MAX_TOKENS = 256 # Default generation length
21
+ DEFAULT_N_THREADS = 2 # Recommended 2 for free CPU tier
 
22
  # ------------------------------
23
 
24
  def log(msg: str):
 
26
 
27
  def load_model_from_hub(repo_id: str, filename: str, n_ctx=DEFAULT_N_CTX, n_threads=DEFAULT_N_THREADS):
28
  if Llama is None:
29
+ raise RuntimeError(f"llama-cpp-python not installed or failed to load: {Llama_import_error}")
30
 
31
+ log(f"Starting model download: {repo_id} / {filename} ...")
32
 
33
+ # Use snapshot_download to download a single file
34
+ # allow_patterns ensures only the GGUF file is downloaded
35
  local_dir = snapshot_download(
36
  repo_id=repo_id,
37
  allow_patterns=[filename],
38
+ local_dir_use_symlinks=False # Disabling symlinks for stability in Spaces
39
  )
40
 
41
+ # Construct full path
42
+ # snapshot_download usually preserves directory structure, otherwise we search
43
  gguf_path = os.path.join(local_dir, filename)
44
 
45
+ # Search for the file if direct path fails (for robustness)
46
  if not os.path.exists(gguf_path):
47
  for root, dirs, files in os.walk(local_dir):
48
  if filename in files:
49
  gguf_path = os.path.join(root, filename)
50
  break
51
+ if not os.path.exists(gguf_path):
52
+ raise FileNotFoundError(f"Could not find {filename} in {local_dir}")
53
+
54
+ log(f"Model path: {gguf_path}. Loading into memory...")
55
 
56
+ # Initialize the model
 
 
 
 
 
57
  llm = Llama(model_path=gguf_path, n_ctx=n_ctx, n_threads=n_threads, verbose=False)
58
+ log("Llama model loaded successfully!")
59
  return llm, gguf_path
60
 
61
  def init_model(state):
62
+ """Callback function for the Load button"""
63
  try:
64
  if state.get("llm") is not None:
65
+ return "✅ System Ready (Model Loaded)", state
66
 
67
+ log("Received load request...")
68
+ # Download and load
69
  llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
70
 
71
+ # Update state
72
  state["llm"] = llm
73
  state["gguf_path"] = gguf_path
74
 
75
+ return "✅ System Ready", state
76
  except Exception as exc:
77
  tb = traceback.format_exc()
78
+ log(f"Initialization Error: {exc}\n{tb}")
79
+ return f"❌ Initialization Failed: {exc}", state
80
 
81
  def generate_response(prompt: str, max_tokens: int, state):
82
+ """Callback function for the Generate button"""
83
  try:
84
  if not prompt or prompt.strip() == "":
85
+ return "⚠️ Please enter an instruction.", "⚠️ Idle", state
86
 
87
+ # Lazy loading: attempt to auto-load if Generate is clicked without explicit initialization
88
  if state.get("llm") is None:
89
  try:
90
+ log("Model not detected, attempting auto-load...")
91
  llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
92
  state["llm"] = llm
93
  state["gguf_path"] = gguf_path
94
  except Exception as e:
95
+ return f"❌ Model Load Failed: {e}", f"❌ Error", state
96
+
97
  llm = state.get("llm")
98
 
99
+ log(f"Generating (Prompt Length={len(prompt)})...")
100
 
101
+ # Construct Llama 3 format Prompt
102
  system_prompt = "You are a helpful AI assistant."
103
+ # Simple concatenation: System + User
104
+ # For strict formatting, use tokenizer.apply_chat_template
105
+ # Using simple text concatenation here for generality, Llama 3 usually understands
106
  full_prompt = f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
107
+
108
+ # Inference
109
  output = llm(
110
  full_prompt,
111
  max_tokens=max_tokens,
112
+ stop=["<|eot_id|>"], # Stop token
113
  echo=False
114
  )
115
 
116
  text = output['choices'][0]['text']
117
+ log("Generation complete.")
118
+ return text, "✅ Generation Complete", state
 
119
  except Exception as exc:
120
  tb = traceback.format_exc()
121
+ log(f"Generation Error: {exc}\n{tb}")
122
+ return f"Runtime Error: {exc}", f"❌ Exception", state
123
 
124
  def soft_clear(current_state):
125
+ """Clear button: only clears text, keeps the model loaded"""
126
+ status = "✅ System Ready" if current_state.get("llm") else "⚪ Not Initialized"
127
  return "", status, current_state
128
 
129
+ # ---------------- Gradio UI Construction ----------------
130
+ # Theme settings
 
131
  theme = gr.themes.Soft(
132
  primary_hue="indigo",
133
  secondary_hue="slate",
134
+ neutral_hue="slate")
 
135
 
136
+ # Custom CSS
137
+ custom_css = """.footer-text { font-size: 0.8em; color: gray; text-align: center; }"""
 
 
138
 
139
+ with gr.Blocks(title="Llama 3.2 Lab2 Project", css=custom_css, theme=theme) as demo:
140
 
141
+ # Header
142
  with gr.Row():
143
  with gr.Column(scale=1):
144
  gr.Markdown("# 🦙 Llama 3.2 (3B) Fine-Tuned Chatbot")
 
149
  """
150
  )
151
  with gr.Column(scale=0, min_width=150):
152
+ status_label = gr.Label(value="⚪ Not Initialized", label="System Status", show_label=False)
153
 
154
+ # Main layout
155
  with gr.Row():
156
+ # Left: Input and Controls
157
  with gr.Column(scale=4):
158
  with gr.Group():
159
  prompt_in = gr.Textbox(
160
  lines=5,
161
+ label="User Instruction (User Input)",
162
+ placeholder="e.g., Explain Quantum Mechanics...",
163
  elem_id="prompt-input"
164
  )
165
+
166
+ with gr.Accordion("⚙️ Advanced Parameters", open=False):
167
  max_tokens = gr.Slider(
168
  minimum=16,
169
  maximum=1024,
170
  step=16,
171
  value=DEFAULT_MAX_TOKENS,
172
+ label="Max Generation Length (Max Tokens)",
173
+ info="Longer generations will take more CPU time."
174
  )
175
+
176
+ with gr.Row():
177
+ init_btn = gr.Button("🚀 1. Load Model", variant="secondary")
178
+ gen_btn = gr.Button("✨ 2. Generate Response", variant="primary")
179
+
180
+ clear_btn = gr.Button("🗑️ Clear Chat", variant="stop")
181
+
182
+ # Right: Output Display
183
  with gr.Column(scale=6):
184
  output_txt = gr.Textbox(
185
+ label="Model Response (Response)",
186
  lines=15,
187
  )
188
 
189
+ # Footer
190
  with gr.Row():
191
  gr.Markdown(
192
+ "⚠️ *Note: Inference runs on a free CPU, so speed may be slow. The model (approx. 2GB) must be downloaded on first run, please be patient.*",
193
  elem_classes=["footer-text"]
194
  )
195
 
196
+ # State storage
197
  state = gr.State({"llm": None, "gguf_path": None, "status": "Not initialized"})
198
 
199
+ # Event binding
200
  init_btn.click(
201
  fn=init_model,
202
  inputs=state,
 
214
  clear_btn.click(fn=soft_clear, inputs=[state], outputs=[prompt_in, status_label, state])
215
  clear_btn.click(lambda: "", outputs=[output_txt])
216
 
217
+ # Launch the application
218
  if __name__ == "__main__":
219
  demo.launch(server_name="0.0.0.0", server_port=7860)