akhaliq HF Staff committed on
Commit
1182537
·
verified ·
1 Parent(s): 0482642

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -56
app.py CHANGED
@@ -5,6 +5,7 @@ from huggingface_hub import login
5
  import os
6
  from typing import List, Dict, Any
7
  import time
 
8
 
9
  # Configuration
10
  MODEL_ID = "facebook/MobileLLM-Pro"
@@ -27,9 +28,11 @@ class MobileLLMChat:
27
  self.tokenizer = None
28
  self.device = None
29
  self.model_loaded = False
 
 
30
 
31
  def load_model(self, version="instruct"):
32
- """Load the MobileLLM-Pro model and tokenizer"""
33
  try:
34
  print(f"Loading MobileLLM-Pro ({version})...")
35
 
@@ -40,23 +43,19 @@ class MobileLLMChat:
40
  subfolder=version
41
  )
42
 
43
- # Load model
44
  self.model = AutoModelForCausalLM.from_pretrained(
45
  MODEL_ID,
46
  trust_remote_code=True,
47
  subfolder=version,
48
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
49
- device_map="auto" if torch.cuda.is_available() else None
50
  )
51
 
52
- # Set device
53
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
- if not torch.cuda.is_available():
55
- self.model.to(self.device)
56
-
57
  self.model.eval()
58
  self.model_loaded = True
59
- print(f"Model loaded successfully on {self.device}")
60
  return True
61
 
62
  except Exception as e:
@@ -73,14 +72,19 @@ class MobileLLMChat:
73
 
74
  return messages
75
 
 
76
  def generate_response(self, user_input: str, history: List[Dict[str, str]],
77
  system_prompt: str, temperature: float = 0.7,
78
  max_new_tokens: int = MAX_NEW_TOKENS) -> str:
79
- """Generate a response from the model"""
80
  if not self.model_loaded:
81
- return "Model not loaded. Please try loading the model first."
82
 
83
  try:
 
 
 
 
84
  # Add user message to history
85
  history.append({"role": "user", "content": user_input})
86
 
@@ -125,19 +129,28 @@ class MobileLLMChat:
125
  # Add assistant response to history
126
  history.append({"role": "assistant", "content": response})
127
 
 
 
 
 
128
  return response
129
 
130
  except Exception as e:
131
  return f"Error generating response: {str(e)}"
132
 
 
133
  def generate_stream(self, user_input: str, history: List[Dict[str, str]],
134
  system_prompt: str, temperature: float = 0.7):
135
- """Generate a streaming response from the model"""
136
  if not self.model_loaded:
137
- yield "Model not loaded. Please try loading the model first."
138
  return
139
 
140
  try:
 
 
 
 
141
  # Add user message to history
142
  history.append({"role": "user", "content": user_input})
143
 
@@ -189,28 +202,25 @@ class MobileLLMChat:
189
  # Add final response to history
190
  history.append({"role": "assistant", "content": response})
191
 
 
 
 
 
192
  except Exception as e:
193
  yield f"Error generating response: {str(e)}"
194
 
195
- # Initialize chat model
 
196
  chat_model = MobileLLMChat()
197
 
198
- def load_model_button(version):
199
- """Load the model when button is clicked"""
200
- success = chat_model.load_model(version)
201
- if success:
202
- return gr.update(visible=False), gr.update(visible=True), gr.update(value="Model loaded successfully!")
203
- else:
204
- return gr.update(visible=True), gr.update(visible=False), gr.update(value="Failed to load model. Please check the logs.")
205
-
206
  def clear_chat():
207
  """Clear the chat history"""
208
  return [], []
209
 
210
- def chat_fn(message, history, system_prompt, temperature, model_version):
211
  """Main chat function"""
212
  if not chat_model.model_loaded:
213
- return "Please load the model first using the button above."
214
 
215
  # Convert history format
216
  formatted_history = []
@@ -224,10 +234,10 @@ def chat_fn(message, history, system_prompt, temperature, model_version):
224
 
225
  return response
226
 
227
- def chat_stream_fn(message, history, system_prompt, temperature, model_version):
228
  """Streaming chat function"""
229
  if not chat_model.model_loaded:
230
- yield "Please load the model first using the button above."
231
  return
232
 
233
  # Convert history format
@@ -275,23 +285,14 @@ with gr.Blocks(
275
  </div>
276
  """)
277
 
278
- # Model loading section
279
  with gr.Row():
280
- with gr.Column(scale=1):
281
- model_version = gr.Dropdown(
282
- choices=["instruct", "base"],
283
- value="instruct",
284
- label="Model Version",
285
- info="Choose between instruct (chat) or base model"
286
- )
287
- load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg")
288
-
289
- with gr.Column(scale=2):
290
- model_status = gr.Textbox(
291
- label="Model Status",
292
- value="Model not loaded",
293
- interactive=False
294
- )
295
 
296
  # Configuration section
297
  with gr.Accordion("⚙️ Configuration", open=False):
@@ -337,29 +338,22 @@ with gr.Blocks(
337
  submit_btn = gr.Button("Send", variant="primary", scale=1)
338
  clear_btn = gr.Button("Clear", scale=0)
339
 
340
- # Event handlers
341
- load_btn.click(
342
- load_model_button,
343
- inputs=[model_version],
344
- outputs=[load_btn, model_status, model_status]
345
- )
346
-
347
  # Handle chat submission
348
- def handle_chat(message, history, system_prompt, temperature, model_version, streaming):
349
  if streaming:
350
- return chat_stream_fn(message, history, system_prompt, temperature, model_version)
351
  else:
352
- return chat_fn(message, history, system_prompt, temperature, model_version)
353
 
354
  msg.submit(
355
  handle_chat,
356
- inputs=[msg, chatbot, system_prompt, temperature, model_version, streaming],
357
  outputs=[chatbot]
358
  )
359
 
360
  submit_btn.click(
361
  handle_chat,
362
- inputs=[msg, chatbot, system_prompt, temperature, model_version, streaming],
363
  outputs=[chatbot]
364
  )
365
 
@@ -384,7 +378,7 @@ with gr.Blocks(
384
  # Footer
385
  gr.HTML("""
386
  <div style="text-align: center; margin-top: 20px; color: #666;">
387
- <p>⚠️ Note: This model requires significant computational resources. Loading may take a few minutes.</p>
388
  <p>Model: <a href="https://huggingface.co/facebook/MobileLLM-Pro" target="_blank">facebook/MobileLLM-Pro</a></p>
389
  </div>
390
  """)
 
5
  import os
6
  from typing import List, Dict, Any
7
  import time
8
+ import spaces
9
 
10
  # Configuration
11
  MODEL_ID = "facebook/MobileLLM-Pro"
 
28
  self.tokenizer = None
29
  self.device = None
30
  self.model_loaded = False
31
+ # Load model on initialization for shared app
32
+ self.load_model()
33
 
34
  def load_model(self, version="instruct"):
35
+ """Load the MobileLLM-Pro model and tokenizer - runs once on CPU/system memory"""
36
  try:
37
  print(f"Loading MobileLLM-Pro ({version})...")
38
 
 
43
  subfolder=version
44
  )
45
 
46
+ # Load model to CPU first for shared app
47
  self.model = AutoModelForCausalLM.from_pretrained(
48
  MODEL_ID,
49
  trust_remote_code=True,
50
  subfolder=version,
51
+ torch_dtype=torch.float16,
52
+ low_cpu_mem_usage=True
53
  )
54
 
55
+ # Model will be moved to GPU during inference
 
 
 
 
56
  self.model.eval()
57
  self.model_loaded = True
58
+ print(f"Model loaded successfully in system memory")
59
  return True
60
 
61
  except Exception as e:
 
72
 
73
  return messages
74
 
75
+ @spaces.GPU(duration=120)
76
  def generate_response(self, user_input: str, history: List[Dict[str, str]],
77
  system_prompt: str, temperature: float = 0.7,
78
  max_new_tokens: int = MAX_NEW_TOKENS) -> str:
79
+ """Generate a response from the model - GPU allocated only during inference"""
80
  if not self.model_loaded:
81
+ return "Model not loaded. Please try reloading the space."
82
 
83
  try:
84
+ # Move model to GPU for inference
85
+ self.device = torch.device("cuda")
86
+ self.model.to(self.device)
87
+
88
  # Add user message to history
89
  history.append({"role": "user", "content": user_input})
90
 
 
129
  # Add assistant response to history
130
  history.append({"role": "assistant", "content": response})
131
 
132
+ # Move model back to CPU after inference to free GPU
133
+ self.model.to("cpu")
134
+ torch.cuda.empty_cache()
135
+
136
  return response
137
 
138
  except Exception as e:
139
  return f"Error generating response: {str(e)}"
140
 
141
+ @spaces.GPU(duration=120)
142
  def generate_stream(self, user_input: str, history: List[Dict[str, str]],
143
  system_prompt: str, temperature: float = 0.7):
144
+ """Generate a streaming response from the model - GPU allocated only during inference"""
145
  if not self.model_loaded:
146
+ yield "Model not loaded. Please try reloading the space."
147
  return
148
 
149
  try:
150
+ # Move model to GPU for inference
151
+ self.device = torch.device("cuda")
152
+ self.model.to(self.device)
153
+
154
  # Add user message to history
155
  history.append({"role": "user", "content": user_input})
156
 
 
202
  # Add final response to history
203
  history.append({"role": "assistant", "content": response})
204
 
205
+ # Move model back to CPU after inference to free GPU
206
+ self.model.to("cpu")
207
+ torch.cuda.empty_cache()
208
+
209
  except Exception as e:
210
  yield f"Error generating response: {str(e)}"
211
 
212
+ # Initialize chat model (loads model once on startup)
213
+ print("Initializing MobileLLM-Pro model...")
214
  chat_model = MobileLLMChat()
215
 
 
 
 
 
 
 
 
 
216
  def clear_chat():
217
  """Clear the chat history"""
218
  return [], []
219
 
220
+ def chat_fn(message, history, system_prompt, temperature):
221
  """Main chat function"""
222
  if not chat_model.model_loaded:
223
+ return "Please wait for the model to load or reload the space."
224
 
225
  # Convert history format
226
  formatted_history = []
 
234
 
235
  return response
236
 
237
+ def chat_stream_fn(message, history, system_prompt, temperature):
238
  """Streaming chat function"""
239
  if not chat_model.model_loaded:
240
+ yield "Please wait for the model to load or reload the space."
241
  return
242
 
243
  # Convert history format
 
285
  </div>
286
  """)
287
 
288
+ # Model status indicator
289
  with gr.Row():
290
+ model_status = gr.Textbox(
291
+ label="Model Status",
292
+ value="Model loaded and ready!" if chat_model.model_loaded else "Model loading...",
293
+ interactive=False,
294
+ container=True
295
+ )
 
 
 
 
 
 
 
 
 
296
 
297
  # Configuration section
298
  with gr.Accordion("⚙️ Configuration", open=False):
 
338
  submit_btn = gr.Button("Send", variant="primary", scale=1)
339
  clear_btn = gr.Button("Clear", scale=0)
340
 
 
 
 
 
 
 
 
341
  # Handle chat submission
342
+ def handle_chat(message, history, system_prompt, temperature, streaming):
343
  if streaming:
344
+ return chat_stream_fn(message, history, system_prompt, temperature)
345
  else:
346
+ return chat_fn(message, history, system_prompt, temperature)
347
 
348
  msg.submit(
349
  handle_chat,
350
+ inputs=[msg, chatbot, system_prompt, temperature, streaming],
351
  outputs=[chatbot]
352
  )
353
 
354
  submit_btn.click(
355
  handle_chat,
356
+ inputs=[msg, chatbot, system_prompt, temperature, streaming],
357
  outputs=[chatbot]
358
  )
359
 
 
378
  # Footer
379
  gr.HTML("""
380
  <div style="text-align: center; margin-top: 20px; color: #666;">
381
+ <p>⚠️ Note: Model is pre-loaded for faster inference. GPU is allocated only during generation.</p>
382
  <p>Model: <a href="https://huggingface.co/facebook/MobileLLM-Pro" target="_blank">facebook/MobileLLM-Pro</a></p>
383
  </div>
384
  """)