Marcus719 committed
Commit c380681 · verified · 1 Parent(s): 3d6744f

Update app.py

Files changed (1): app.py (+105 -94)

app.py CHANGED
@@ -4,22 +4,21 @@ import time
  from huggingface_hub import snapshot_download
  import gradio as gr
 
- # Attempt to import llama_cpp, handle failure gracefully in UI
  try:
      from llama_cpp import Llama
  except Exception as e:
      Llama = None
      Llama_import_error = e
 
- # ---------- Configuration ----------
- # ★★★ Replace with your model repository if needed ★★★
- MODEL_REPO = "Marcus719/Llama-3.2-3B-Instruct-FineTome-Lab2-GGUF"
  GGUF_FILENAME = "unsloth.Q4_K_M.gguf"
-
- DEFAULT_N_CTX = 2048  # Context Window
  DEFAULT_MAX_TOKENS = 256  # Default generation length
- DEFAULT_N_THREADS = 2  # Recommended for free CPU tier
-
  # ------------------------------
 
  def log(msg: str):
@@ -27,178 +26,190 @@ def log(msg: str):
 
  def load_model_from_hub(repo_id: str, filename: str, n_ctx=DEFAULT_N_CTX, n_threads=DEFAULT_N_THREADS):
      if Llama is None:
-         raise RuntimeError(f"llama-cpp-python not installed: {Llama_import_error}")
 
-     log(f"Downloading model: {repo_id} / {filename} ...")
 
-     # Download specific GGUF file
      local_dir = snapshot_download(
          repo_id=repo_id,
          allow_patterns=[filename],
-         local_dir_use_symlinks=False
      )
 
      gguf_path = os.path.join(local_dir, filename)
 
-     # Fallback search if path joining fails
      if not os.path.exists(gguf_path):
          for root, dirs, files in os.walk(local_dir):
              if filename in files:
                  gguf_path = os.path.join(root, filename)
                  break
-
-     if not os.path.exists(gguf_path):
-         raise FileNotFoundError(f"Could not find {filename} in {local_dir}")
-
      log(f"Model path: {gguf_path}. Loading into memory...")
 
-     # Initialize Llama
      llm = Llama(model_path=gguf_path, n_ctx=n_ctx, n_threads=n_threads, verbose=False)
      log("Llama model loaded successfully!")
      return llm, gguf_path
 
  def init_model(state):
-     """Callback for Load Model button"""
      try:
          if state.get("llm") is not None:
-             return "Ready (Model already loaded)", state
 
-         log("Load request received...")
          llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
 
          state["llm"] = llm
          state["gguf_path"] = gguf_path
 
-         return "Ready", state
      except Exception as exc:
          tb = traceback.format_exc()
-         log(f"Init Error: {exc}\n{tb}")
-         return f"Load Failed: {exc}", state
 
  def generate_response(prompt: str, max_tokens: int, state):
-     """Callback for Generate button"""
      try:
          if not prompt or prompt.strip() == "":
-             return "Please enter a prompt.", "Idle", state
 
-         # Auto-load if not initialized
          if state.get("llm") is None:
              try:
-                 log("Model not found, auto-loading...")
                  llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
                  state["llm"] = llm
                  state["gguf_path"] = gguf_path
              except Exception as e:
-                 return f"Model Load Error: {e}", f"Error", state
-
          llm = state.get("llm")
-         log(f"Generating (Prompt len={len(prompt)})...")
 
-         # Llama 3 Prompt Format
          system_prompt = "You are a helpful AI assistant."
-         full_prompt = (
-             f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>"
-             f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>"
-             f"<|start_header_id|>assistant<|end_header_id|>\n\n"
-         )
 
          output = llm(
              full_prompt,
              max_tokens=max_tokens,
-             stop=["<|eot_id|>"],
              echo=False
          )
 
          text = output['choices'][0]['text']
          log("Generation complete.")
-         return text, "Done", state
-
      except Exception as exc:
          tb = traceback.format_exc()
          log(f"Generation Error: {exc}\n{tb}")
-         return f"Runtime Error: {exc}", f"Exception", state
 
  def soft_clear(current_state):
-     """Clear text only, keep model in memory"""
-     status = "Ready" if current_state.get("llm") else "Not initialized"
-     return "", status, current_state
 
- # ---------------- Gradio UI ----------------
 
- # Removed css and theme arguments for compatibility
- with gr.Blocks(title="Llama 3.2 Lab Interface") as demo:
 
      # Header
      with gr.Row():
-         with gr.Column(scale=5):
-             gr.Markdown("## Llama 3.2 (3B) Instruct - Lab Interface")
-             gr.Markdown(f"Running **{GGUF_FILENAME}** on CPU via `llama.cpp`.")
-         with gr.Column(scale=1, min_width=150):
-             status_label = gr.Label(value="Not initialized", label="Status", show_label=False)
-
-     gr.Markdown("---")
 
-     # Main Layout
      with gr.Row():
-         # LEFT: Controls
-         with gr.Column(scale=1):
-             prompt_in = gr.Textbox(
-                 lines=8,
-                 label="Input Prompt",
-                 placeholder="Type your question here (e.g., Explain Quantum Mechanics)...",
-                 elem_id="prompt-input"
-             )
-
-             with gr.Accordion("Settings", open=False):
-                 max_tokens = gr.Slider(
-                     minimum=16,
-                     maximum=1024,
-                     step=16,
-                     value=DEFAULT_MAX_TOKENS,
-                     label="Max New Tokens",
-                     info="Higher values take longer to generate."
                  )
-
-             with gr.Row():
-                 init_btn = gr.Button("Download & Load Model", variant="secondary")
-                 clear_btn = gr.Button("Clear", variant="stop")
-
-             gen_btn = gr.Button("Generate Response", variant="primary")
-
-         # RIGHT: Output
-         with gr.Column(scale=1):
-             # Removed 'show_copy_button=True' to fix TypeError
              output_txt = gr.Textbox(
-                 label="Model Response",
-                 lines=14,
-                 interactive=False
              )
 
      # Footer
-     gr.Markdown(
-         "**Note:** First run requires downloading ~2GB model. Inference runs on CPU and may be slow."
-     )
 
-     # State Management
-     state = gr.State({"llm": None, "gguf_path": None})
 
-     # Event Handlers
      init_btn.click(
          fn=init_model,
          inputs=state,
-         outputs=[status_label, state],
          show_progress=True
      )
 
      gen_btn.click(
          fn=generate_response,
          inputs=[prompt_in, max_tokens, state],
-         outputs=[output_txt, status_label, state],
          show_progress=True
      )
 
-     clear_btn.click(fn=soft_clear, inputs=[state], outputs=[prompt_in, status_label, state])
      clear_btn.click(lambda: "", outputs=[output_txt])
 
  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=7860)
 
@@ -4,22 +4,21 @@ import time
  from huggingface_hub import snapshot_download
  import gradio as gr
 
+ # Attempt to import llama_cpp; if it fails, surface the error in the UI
  try:
      from llama_cpp import Llama
  except Exception as e:
      Llama = None
      Llama_import_error = e
 
+ # ---------- Configuration ----------
+ # ★★★ Change this to your model repository ★★★
+ MODEL_REPO = "Marcus719/Llama-3.2-1B-Compare1-Lab2-GGUF"
+ # Download only the Q4_K_M file to avoid running out of disk space
  GGUF_FILENAME = "unsloth.Q4_K_M.gguf"
+ DEFAULT_N_CTX = 2048  # Context length
  DEFAULT_MAX_TOKENS = 256  # Default generation length
+ DEFAULT_N_THREADS = 2  # 2 threads recommended for the free CPU tier
  # ------------------------------
 
  def log(msg: str):
@@ -27,178 +26,190 @@ def log(msg: str):
 
  def load_model_from_hub(repo_id: str, filename: str, n_ctx=DEFAULT_N_CTX, n_threads=DEFAULT_N_THREADS):
      if Llama is None:
+         raise RuntimeError(f"llama-cpp-python not installed or failed to load: {Llama_import_error}")
 
+     log(f"Starting model download: {repo_id} / {filename} ...")
 
+     # Use snapshot_download to fetch a single file;
+     # allow_patterns ensures only the GGUF file is downloaded
      local_dir = snapshot_download(
          repo_id=repo_id,
          allow_patterns=[filename],
+         local_dir_use_symlinks=False  # Disable symlinks for stability in Spaces
      )
 
+     # Construct the full path; snapshot_download usually preserves the
+     # directory structure, otherwise we search for the file
      gguf_path = os.path.join(local_dir, filename)
 
+     # Walk the download directory if the direct path fails (for robustness)
      if not os.path.exists(gguf_path):
          for root, dirs, files in os.walk(local_dir):
              if filename in files:
                  gguf_path = os.path.join(root, filename)
                  break
+     if not os.path.exists(gguf_path):
+         raise FileNotFoundError(f"Could not find {filename} in {local_dir}")
+
      log(f"Model path: {gguf_path}. Loading into memory...")
 
+     # Initialize the model
      llm = Llama(model_path=gguf_path, n_ctx=n_ctx, n_threads=n_threads, verbose=False)
      log("Llama model loaded successfully!")
      return llm, gguf_path
 
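Since only one file is ever needed, `hf_hub_download` would be a lighter-weight alternative to the `snapshot_download` + `os.walk` fallback above, as it returns the path of the single file directly. A minimal sketch (not part of this commit), reusing the repo and filename from the configuration:

```python
from huggingface_hub import hf_hub_download

# Downloads a single file from the repo and returns its local path directly,
# so no fallback search over the download directory is needed.
gguf_path = hf_hub_download(
    repo_id="Marcus719/Llama-3.2-1B-Compare1-Lab2-GGUF",
    filename="unsloth.Q4_K_M.gguf",
)
print(gguf_path)  # absolute path to the cached .gguf file
```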
  def init_model(state):
+     """Callback for the Load Model button"""
      try:
          if state.get("llm") is not None:
+             return state
 
+         log("Received load request...")
+         # Download and load
          llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
 
+         # Update state
          state["llm"] = llm
          state["gguf_path"] = gguf_path
 
+         return state
      except Exception as exc:
          tb = traceback.format_exc()
+         log(f"Initialization Error: {exc}\n{tb}")
+         return state
 
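The `gr.State` round trip used here is the key pattern: a handler receives the current state value as an input, mutates or replaces it, and whatever it returns for the state output becomes the new per-session value. A tiny self-contained sketch of the same pattern (hypothetical names, not from this commit):

```python
import gradio as gr

def bump(state):
    # state arrives as the plain dict stored in gr.State
    state["clicks"] = state.get("clicks", 0) + 1
    # returning it as the state output persists it for the next event
    return f"Clicked {state['clicks']} times", state

with gr.Blocks() as demo:
    state = gr.State({"clicks": 0})  # per-session, like the llm/gguf_path dict above
    out = gr.Textbox(label="Output")
    gr.Button("Bump").click(fn=bump, inputs=state, outputs=[out, state])

demo.launch()
```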
  def generate_response(prompt: str, max_tokens: int, state):
+     """Callback for the Generate button"""
      try:
          if not prompt or prompt.strip() == "":
+             return "Please enter an instruction.", state
 
+         # Lazy loading: auto-load the model if Generate is clicked
+         # before explicit initialization
          if state.get("llm") is None:
              try:
+                 log("Model not loaded yet, attempting auto-load...")
                  llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
                  state["llm"] = llm
                  state["gguf_path"] = gguf_path
              except Exception as e:
+                 return f"Model Load Failed: {e}", state
+
          llm = state.get("llm")
 
+         log(f"Generating (prompt length={len(prompt)})...")
+
+         # Build a Llama 3 format prompt: simple system + user concatenation.
+         # For strict formatting, use tokenizer.apply_chat_template; plain
+         # concatenation is used here for generality, and Llama 3 usually
+         # understands it.
          system_prompt = "You are a helpful AI assistant."
+         full_prompt = f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 
+         # Inference
          output = llm(
              full_prompt,
              max_tokens=max_tokens,
+             stop=["<|eot_id|>"],  # Stop at the end-of-turn token
              echo=False
          )
 
          text = output['choices'][0]['text']
          log("Generation complete.")
+         return text, state
 
      except Exception as exc:
          tb = traceback.format_exc()
          log(f"Generation Error: {exc}\n{tb}")
+         return f"Runtime Error: {exc}", state
 
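As the comment above notes, hand-built header tokens can drift from the model's real template. llama-cpp-python can instead apply the chat template stored in the GGUF metadata via `create_chat_completion`. A minimal sketch, assuming `llm` is the instance returned by `load_model_from_hub`:

```python
# Let llama.cpp format the conversation with the model's own chat template,
# instead of concatenating <|start_header_id|> tokens by hand.
output = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Explain Quantum Mechanics."},
    ],
    max_tokens=256,
)
text = output["choices"][0]["message"]["content"]
print(text)
```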
  def soft_clear(current_state):
+     """Clear button: clears the text only, keeps the model loaded"""
+     return "", current_state
 
 
+ # ---------------- Gradio UI Construction ----------------
+ # Theme settings
+ theme = gr.themes.Soft(
+     primary_hue="indigo",
+     secondary_hue="slate",
+     neutral_hue="slate")
 
+ # Custom CSS
+ custom_css = """.footer-text { font-size: 0.8em; color: gray; text-align: center; }"""
+
+ with gr.Blocks(title="Llama 3.2 Lab2 Project", theme=theme, css=custom_css) as demo:
 
      # Header
      with gr.Row():
+         gr.Markdown("# Llama 3.2 (1B) Fine-Tuned Chatbot")
+         gr.Markdown(
+             f"""
+             **ID2223 Lab 2 Project** | Fine-tuned on **FineTome-100k**.
+             Running on CPU (GGUF 4-bit) | Model: `{MODEL_REPO}`
+             """
+         )
 
+     # Main layout
      with gr.Row():
+         # Left: input and controls
+         with gr.Column(scale=4):
+             with gr.Group():
+                 prompt_in = gr.Textbox(
+                     lines=5,
+                     label="User Instruction",
+                     placeholder="e.g., Explain Quantum Mechanics...",
+                     elem_id="prompt-input"
                  )
+
+             with gr.Accordion("Advanced Parameters", open=False):
+                 max_tokens = gr.Slider(
+                     minimum=16,
+                     maximum=1024,
+                     step=16,
+                     value=DEFAULT_MAX_TOKENS,
+                     label="Max Generation Length (Max Tokens)",
+                     info="Longer generations take more CPU time."
+                 )
+
+             with gr.Row():
+                 init_btn = gr.Button("1. Load Model", variant="secondary")
+                 gen_btn = gr.Button("2. Generate Response", variant="primary")
+
+             clear_btn = gr.Button("Clear Chat", variant="stop")
+
+         # Right: output display
+         with gr.Column(scale=6):
              output_txt = gr.Textbox(
+                 label="Model Response",
+                 lines=15,
              )
 
      # Footer
+     with gr.Row():
+         gr.Markdown(
+             "*Note: Inference runs on a free CPU, so generation may be slow. "
+             "The model (approx. 2 GB) must be downloaded on the first run; please be patient.*",
+             elem_classes=["footer-text"]
+         )
 
+     # State storage
+     state = gr.State({"llm": None, "gguf_path": None, "status": "Not initialized"})
 
+     # Event binding
      init_btn.click(
          fn=init_model,
          inputs=state,
+         outputs=[state],
          show_progress=True
      )
 
      gen_btn.click(
          fn=generate_response,
          inputs=[prompt_in, max_tokens, state],
+         outputs=[output_txt, state],
          show_progress=True
      )
 
+     clear_btn.click(fn=soft_clear, inputs=[state], outputs=[prompt_in, state])
      clear_btn.click(lambda: "", outputs=[output_txt])
 
+ # Launch the application
  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=7860)
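One further note on responsiveness: on the free CPU tier the output box stays empty until generation finishes. llama-cpp-python supports `stream=True`, and Gradio treats generator callbacks as streaming handlers, so a streaming variant of `generate_response` is straightforward. A hypothetical sketch (not part of this commit; assumes the model is already loaded in `state`):

```python
def generate_response_stream(prompt: str, max_tokens: int, state):
    """Hypothetical streaming variant: yields partial text as tokens arrive."""
    llm = state["llm"]  # assumes the model was loaded beforehand
    full_prompt = (
        f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    text = ""
    # With stream=True, llama-cpp-python yields chunks instead of one dict
    for chunk in llm(full_prompt, max_tokens=max_tokens, stop=["<|eot_id|>"], stream=True):
        text += chunk["choices"][0]["text"]
        yield text, state  # Gradio pushes each yield to output_txt

# Wiring would mirror the existing handler:
# gen_btn.click(fn=generate_response_stream,
#               inputs=[prompt_in, max_tokens, state],
#               outputs=[output_txt, state])
```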