Joe6636564 committed on
Commit
2949c89
·
verified ·
1 Parent(s): 9e0c9bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -157
app.py CHANGED
@@ -32,16 +32,12 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
32
  print("Loading tokenizer and model...")
33
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID1)
34
 
35
- # Add padding token if it doesn't exist
36
- if tokenizer.pad_token is None:
37
- tokenizer.pad_token = tokenizer.eos_token
38
-
39
  # CPU-only model loading
40
  model = AutoModelForCausalLM.from_pretrained(
41
  MODEL_ID1,
42
- torch_dtype=torch.float32,
43
  device_map="cpu",
44
- low_cpu_mem_usage=True
45
  )
46
 
47
  # Vision model setup
@@ -53,10 +49,9 @@ try:
53
  models["microsoft/Phi-3.5-vision-instruct"] = AutoModelForCausalLM.from_pretrained(
54
  "microsoft/Phi-3.5-vision-instruct",
55
  trust_remote_code=True,
56
- torch_dtype=torch.float32,
57
  device_map="cpu",
58
- low_cpu_mem_usage=True,
59
- _attn_implementation=None
60
  ).eval()
61
 
62
  processors["microsoft/Phi-3.5-vision-instruct"] = AutoProcessor.from_pretrained(
@@ -67,20 +62,19 @@ try:
67
  except Exception as e:
68
  print(f"Error loading vision model: {e}")
69
 
70
- # Optimized chatbot function with better generation parameters
71
  def stream_chat(
72
  message: str,
73
  history: list,
74
  system_prompt: str,
75
- temperature: float = 0.7, # Lower temperature for more focused responses
76
  max_new_tokens: int = 1024,
77
- top_p: float = 0.9, # Lower top_p for less randomness
78
- top_k: int = 40, # Moderate top_k
79
- repetition_penalty: float = 1.1, # Lower repetition penalty
80
  ):
81
  print(f'message: {message}')
82
  print(f'history: {history}')
83
-
84
  conversation = [{"role": "system", "content": system_prompt}]
85
 
86
  for prompt, answer in history:
@@ -90,35 +84,18 @@ def stream_chat(
90
  ])
91
 
92
  conversation.append({"role": "user", "content": message})
 
93
 
94
- # Apply chat template
95
- input_ids = tokenizer.apply_chat_template(
96
- conversation,
97
- add_generation_prompt=True,
98
- return_tensors="pt"
99
- ).to(device)
100
-
101
- streamer = TextIteratorStreamer(
102
- tokenizer,
103
- timeout=60.0,
104
- skip_prompt=True,
105
- skip_special_tokens=True
106
- )
107
-
108
- # Optimized generation parameters to reduce repetition
109
  generate_kwargs = dict(
110
  input_ids=input_ids,
111
  max_new_tokens=max_new_tokens,
112
- temperature=temperature,
113
  top_p=top_p,
114
  top_k=top_k,
115
- repetition_penalty=repetition_penalty, # Use repetition_penalty instead of penalty
116
- do_sample=True if temperature > 0 else False,
117
- pad_token_id=tokenizer.eos_token_id,
118
- eos_token_id=[tokenizer.eos_token_id, 128001, 128008, 128009],
119
  streamer=streamer,
120
- no_repeat_ngram_size=3, # Prevent repeating n-grams
121
- early_stopping=True,
122
  )
123
 
124
  with torch.no_grad():
@@ -130,7 +107,7 @@ def stream_chat(
130
  buffer += new_text
131
  yield buffer
132
 
133
- # Optimized vision model function
134
  def stream_vision(image, text_input=None, model_id="microsoft/Phi-3.5-vision-instruct"):
135
  if model_id not in models:
136
  return "Vision model not available"
@@ -160,48 +137,39 @@ def stream_vision(image, text_input=None, model_id="microsoft/Phi-3.5-vision-ins
160
  # Process the inputs with the processor
161
  inputs = processor(prompt, images, return_tensors="pt").to(device)
162
 
163
- # Optimized generation parameters for vision model
164
  generation_args = {
165
- "max_new_tokens": 500,
166
- "temperature": 0.3, # Lower temperature for more factual responses
167
- "top_p": 0.9,
168
- "top_k": 30,
169
- "repetition_penalty": 1.1,
170
- "do_sample": True,
171
- "no_repeat_ngram_size": 3,
172
- "early_stopping": True,
173
- "eos_token_id": processor.tokenizer.eos_token_id,
174
  }
175
 
176
  # Generate the response
177
- try:
178
- generate_ids = model_vision.generate(
179
- **inputs,
180
- **generation_args
181
- )
182
-
183
- # Remove input tokens from the generated response
184
- generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
185
-
186
- # Decode the generated output
187
- response = processor.batch_decode(
188
- generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
189
- )[0]
190
-
191
- return response
192
- except Exception as e:
193
- return f"Error generating vision response: {str(e)}"
194
 
195
- # Flask API Routes with optimized parameters
196
  @flask_app.route('/health', methods=['GET'])
197
  def health_check():
198
- vision_loaded = len(models) > 0 and "microsoft/Phi-3.5-vision-instruct" in models
199
  return jsonify({
200
  "status": "healthy",
201
  "device": device,
202
  "models_loaded": {
203
  "chatbot": MODEL_ID1 in globals() and 'model' in globals(),
204
- "vision": vision_loaded
205
  }
206
  })
207
 
@@ -211,10 +179,8 @@ def api_chat():
211
  data = request.json
212
  message = data.get('message', '')
213
  system_prompt = data.get('system_prompt', 'You are a helpful assistant')
214
- temperature = data.get('temperature', 0.7) # Default to lower temperature
215
- max_new_tokens = data.get('max_new_tokens', 512)
216
- top_p = data.get('top_p', 0.9)
217
- repetition_penalty = data.get('repetition_penalty', 1.1)
218
 
219
  # Prepare conversation
220
  conversation = [{"role": "system", "content": system_prompt}]
@@ -224,26 +190,20 @@ def api_chat():
224
  conversation, add_generation_prompt=True, return_tensors="pt"
225
  ).to(device)
226
 
227
- # Generate response with optimized parameters
228
  with torch.no_grad():
229
  generate_ids = model.generate(
230
  input_ids,
231
  max_new_tokens=max_new_tokens,
232
  temperature=temperature,
233
- top_p=top_p,
234
- repetition_penalty=repetition_penalty,
235
  do_sample=temperature > 0,
236
- no_repeat_ngram_size=3,
237
- early_stopping=True,
238
- eos_token_id=[tokenizer.eos_token_id, 128001, 128008, 128009],
239
- pad_token_id=tokenizer.eos_token_id,
240
  )
241
 
242
  # Decode response
243
  response = tokenizer.decode(
244
  generate_ids[0][input_ids.shape[1]:],
245
- skip_special_tokens=True,
246
- clean_up_tokenization_spaces=True
247
  )
248
 
249
  return jsonify({
@@ -289,142 +249,126 @@ def api_vision():
289
 
290
  @flask_app.route('/api/models', methods=['GET'])
291
  def get_models():
292
- vision_loaded = len(models) > 0 and "microsoft/Phi-3.5-vision-instruct" in models
293
  return jsonify({
294
  "chat_model": MODEL_ID1,
295
- "vision_models": list(models.keys()) if vision_loaded else [],
296
- "device": device,
297
- "vision_available": vision_loaded
298
  })
299
 
300
  def run_flask():
301
  flask_app.run(host='0.0.0.0', port=5000, debug=False, threaded=True)
302
 
303
  def run_gradio():
 
304
  CSS = """.duplicate-button { margin: auto !important; color: white !important; background: black !important; border-radius: 100vh !important;}h3 { text-align: center;}"""
305
  PLACEHOLDER = """<center><p>Hi! I'm your assistant. Feel free to ask your questions</p></center>"""
306
-
307
- vision_available = len(models) > 0 and "microsoft/Phi-3.5-vision-instruct" in models
308
- vision_status = "Available" if vision_available else "Not Available"
309
-
310
- TITLE = f"<h1><center>Phi-3.5 Chatbot & Phi-3.5 Vision (Optimized CPU Version)</center></h1>"
311
- EXPLANATION = f"""<div style="text-align: center; margin-top: 20px;">
312
- <p><strong>Optimized CPU Version</strong> - Better response quality with reduced repetition</p>
313
- <p><strong>Vision Model Status:</strong> {vision_status}</p>
314
- <p><strong>Optimizations applied:</strong> Lower temperature, repetition penalty, and no-repeat n-gram size</p>
315
  </div>"""
316
  footer = """<div style="text-align: center; margin-top: 20px;">
317
- <br> Made with 💖 by Pejman Ebrahimi | Running on CPU with optimized parameters
 
 
 
 
318
  </div>"""
319
 
320
- with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
 
321
  gr.HTML(TITLE)
322
  gr.HTML(EXPLANATION)
323
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
324
 
325
  with gr.Tab("Chatbot"):
326
- chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER, type="messages")
327
  gr.ChatInterface(
328
  fn=stream_chat,
329
  chatbot=chatbot,
330
  fill_height=True,
331
- additional_inputs_accordion=gr.Accordion(label="⚙️ Advanced Parameters", open=False),
332
  additional_inputs=[
333
  gr.Textbox(
334
- value="You are a helpful AI assistant. Provide accurate, concise, and non-repetitive responses.",
335
  label="System Prompt",
 
336
  ),
337
  gr.Slider(
338
- minimum=0.1,
339
- maximum=1.0,
340
  step=0.1,
341
- value=0.7,
342
- label="Temperature (lower = more focused)",
 
343
  ),
344
  gr.Slider(
345
  minimum=128,
346
- maximum=2048,
347
  step=1,
348
- value=512,
349
  label="Max new tokens",
 
350
  ),
351
  gr.Slider(
352
- minimum=0.5,
353
  maximum=1.0,
354
  step=0.1,
355
- value=0.9,
356
- label="Top-p (nucleus sampling)",
 
357
  ),
358
  gr.Slider(
359
  minimum=1,
360
- maximum=100,
361
  step=1,
362
- value=40,
363
- label="Top-k",
 
364
  ),
365
  gr.Slider(
366
- minimum=1.0,
367
  maximum=2.0,
368
  step=0.1,
369
- value=1.1,
370
- label="Repetition Penalty",
 
371
  ),
372
  ],
373
  examples=[
374
- ["Explain the concept of machine learning in simple terms"],
375
- ["What are the main differences between Python and JavaScript?"],
376
- ["How does photosynthesis work in plants?"],
377
- ["Write a brief summary of the history of the internet"],
378
  ],
379
  cache_examples=False,
380
  )
381
 
382
- if vision_available:
383
- with gr.Tab("Vision"):
384
- with gr.Row():
385
- input_img = gr.Image(label="Input Picture")
386
- with gr.Row():
387
- model_selector = gr.Dropdown(
388
- choices=list(models.keys()),
389
- label="Model",
390
- value="microsoft/Phi-3.5-vision-instruct",
391
- allow_custom_value=False
392
- )
393
- with gr.Row():
394
- text_input = gr.Textbox(
395
- label="Question",
396
- value="Describe what you see in this image in detail without repetition.",
397
- placeholder="Ask a specific question about the image..."
398
- )
399
- with gr.Row():
400
- submit_btn = gr.Button(value="Analyze Image")
401
- with gr.Row():
402
- output_text = gr.Textbox(
403
- label="Analysis Result",
404
- lines=5
405
- )
406
-
407
- submit_btn.click(stream_vision, [input_img, text_input, model_selector], [output_text])
408
- else:
409
- with gr.Tab("Vision"):
410
- gr.HTML("""<div style="text-align: center; padding: 40px;">
411
- <h3>Vision Model Not Available</h3>
412
- <p>The vision model failed to load due to memory constraints.</p>
413
- </div>""")
414
 
415
  gr.HTML(footer)
416
 
 
417
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
418
 
419
  if __name__ == "__main__":
420
- print("=" * 50)
421
- print("Application Starting Up...")
422
- print(f"Device: {device}")
423
- print(f"Chat model loaded: {MODEL_ID1}")
424
- print(f"Vision model loaded: {len(models) > 0}")
425
- print("=" * 50)
426
-
427
  flask_thread = threading.Thread(target=run_flask, daemon=True)
428
  flask_thread.start()
429
 
 
430
  run_gradio()
 
32
  print("Loading tokenizer and model...")
33
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID1)
34
 
 
 
 
 
35
  # CPU-only model loading
36
  model = AutoModelForCausalLM.from_pretrained(
37
  MODEL_ID1,
38
+ torch_dtype=torch.float32, # Use float32 for CPU
39
  device_map="cpu",
40
+ low_cpu_mem_usage=True # Optimize for CPU memory
41
  )
42
 
43
  # Vision model setup
 
49
  models["microsoft/Phi-3.5-vision-instruct"] = AutoModelForCausalLM.from_pretrained(
50
  "microsoft/Phi-3.5-vision-instruct",
51
  trust_remote_code=True,
52
+ torch_dtype=torch.float32, # Use float32 for CPU
53
  device_map="cpu",
54
+ low_cpu_mem_usage=True # Optimize for CPU memory
 
55
  ).eval()
56
 
57
  processors["microsoft/Phi-3.5-vision-instruct"] = AutoProcessor.from_pretrained(
 
62
  except Exception as e:
63
  print(f"Error loading vision model: {e}")
64
 
65
+ # Chatbot function
66
  def stream_chat(
67
  message: str,
68
  history: list,
69
  system_prompt: str,
70
+ temperature: float = 0.8,
71
  max_new_tokens: int = 1024,
72
+ top_p: float = 1.0,
73
+ top_k: int = 20,
74
+ penalty: float = 1.2,
75
  ):
76
  print(f'message: {message}')
77
  print(f'history: {history}')
 
78
  conversation = [{"role": "system", "content": system_prompt}]
79
 
80
  for prompt, answer in history:
 
84
  ])
85
 
86
  conversation.append({"role": "user", "content": message})
87
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(device)
88
 
89
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  generate_kwargs = dict(
91
  input_ids=input_ids,
92
  max_new_tokens=max_new_tokens,
93
+ do_sample=False if temperature == 0 else True,
94
  top_p=top_p,
95
  top_k=top_k,
96
+ temperature=temperature,
97
+ eos_token_id=[128001,128008,128009],
 
 
98
  streamer=streamer,
 
 
99
  )
100
 
101
  with torch.no_grad():
 
107
  buffer += new_text
108
  yield buffer
109
 
110
+ # Vision model function
111
  def stream_vision(image, text_input=None, model_id="microsoft/Phi-3.5-vision-instruct"):
112
  if model_id not in models:
113
  return "Vision model not available"
 
137
  # Process the inputs with the processor
138
  inputs = processor(prompt, images, return_tensors="pt").to(device)
139
 
140
+ # Generation parameters
141
  generation_args = {
142
+ "max_new_tokens": 500, # Reduced for CPU
143
+ "temperature": 0.0,
144
+ "do_sample": False,
 
 
 
 
 
 
145
  }
146
 
147
  # Generate the response
148
+ generate_ids = model_vision.generate(
149
+ **inputs,
150
+ eos_token_id=processor.tokenizer.eos_token_id,
151
+ **generation_args
152
+ )
153
+
154
+ # Remove input tokens from the generated response
155
+ generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
156
+
157
+ # Decode the generated output
158
+ response = processor.batch_decode(
159
+ generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
160
+ )[0]
161
+
162
+ return response
 
 
163
 
164
+ # Flask API Routes
165
  @flask_app.route('/health', methods=['GET'])
166
  def health_check():
 
167
  return jsonify({
168
  "status": "healthy",
169
  "device": device,
170
  "models_loaded": {
171
  "chatbot": MODEL_ID1 in globals() and 'model' in globals(),
172
+ "vision": len(models) > 0
173
  }
174
  })
175
 
 
179
  data = request.json
180
  message = data.get('message', '')
181
  system_prompt = data.get('system_prompt', 'You are a helpful assistant')
182
+ temperature = data.get('temperature', 0.8)
183
+ max_new_tokens = data.get('max_new_tokens', 512) # Reduced for CPU
 
 
184
 
185
  # Prepare conversation
186
  conversation = [{"role": "system", "content": system_prompt}]
 
190
  conversation, add_generation_prompt=True, return_tensors="pt"
191
  ).to(device)
192
 
193
+ # Generate response
194
  with torch.no_grad():
195
  generate_ids = model.generate(
196
  input_ids,
197
  max_new_tokens=max_new_tokens,
198
  temperature=temperature,
 
 
199
  do_sample=temperature > 0,
200
+ eos_token_id=[128001, 128008, 128009]
 
 
 
201
  )
202
 
203
  # Decode response
204
  response = tokenizer.decode(
205
  generate_ids[0][input_ids.shape[1]:],
206
+ skip_special_tokens=True
 
207
  )
208
 
209
  return jsonify({
 
249
 
250
  @flask_app.route('/api/models', methods=['GET'])
251
  def get_models():
 
252
  return jsonify({
253
  "chat_model": MODEL_ID1,
254
+ "vision_models": list(models.keys()),
255
+ "device": device
 
256
  })
257
 
258
  def run_flask():
259
  flask_app.run(host='0.0.0.0', port=5000, debug=False, threaded=True)
260
 
261
  def run_gradio():
262
+ # CSS for the interface
263
  CSS = """.duplicate-button { margin: auto !important; color: white !important; background: black !important; border-radius: 100vh !important;}h3 { text-align: center;}"""
264
  PLACEHOLDER = """<center><p>Hi! I'm your assistant. Feel free to ask your questions</p></center>"""
265
+ TITLE = "<h1><center>Phi-3.5 Chatbot & Phi-3.5 Vision (CPU Version)</center></h1>"
266
+ EXPLANATION = """<div style="text-align: center; margin-top: 20px;">
267
+ <p><strong>CPU-Only Version</strong> - This instance is running on CPU. Responses may be slower than GPU-accelerated versions.</p>
268
+ <p>This app supports both the microsoft/Phi-3.5-mini-instruct model for chat bot and the microsoft/Phi-3.5-vision-instruct model for multimodal model.</p>
269
+ <p>Phi-3.5-vision is a lightweight, state-of-the-art open multimodal model built upon datasets which include - synthetic data and filtered publicly available websites - with a focus on very high-quality, reasoning dense data both on text and vision. The model belongs to the Phi-3 model family, and the multimodal version comes with 128K context length (in tokens) it can support.</p>
270
+ <p>Phi-3.5-mini is a lightweight, state-of-the-art open model built upon datasets used for Phi-3 - synthetic data and filtered publicly available websites - with a focus on very high-quality, reasoning dense data. The model belongs to the Phi-3 model family and supports 128K token context length.</p>
 
 
 
271
  </div>"""
272
  footer = """<div style="text-align: center; margin-top: 20px;">
273
+ <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
274
+ <a href="https://github.com/arad1367" target="_blank">GitHub</a> |
275
+ <a href="https://huggingface.co/microsoft/Phi-3.5-mini-instruct" target="_blank">microsoft/Phi-3.5-mini-instruct</a> |
276
+ <a href="https://huggingface.co/microsoft/Phi-3.5-vision-instruct" target="_blank">microsoft/Phi-3.5-vision-instruct</a>
277
+ <br> Made with 💖 by Pejman Ebrahimi | Running on CPU
278
  </div>"""
279
 
280
+ # Gradio app with two tabs
281
+ with gr.Blocks(css=CSS, theme="small_and_pretty") as demo:
282
  gr.HTML(TITLE)
283
  gr.HTML(EXPLANATION)
284
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
285
 
286
  with gr.Tab("Chatbot"):
287
+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
288
  gr.ChatInterface(
289
  fn=stream_chat,
290
  chatbot=chatbot,
291
  fill_height=True,
292
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
293
  additional_inputs=[
294
  gr.Textbox(
295
+ value="You are a helpful assistant",
296
  label="System Prompt",
297
+ render=False,
298
  ),
299
  gr.Slider(
300
+ minimum=0,
301
+ maximum=1,
302
  step=0.1,
303
+ value=0.8,
304
+ label="Temperature",
305
+ render=False,
306
  ),
307
  gr.Slider(
308
  minimum=128,
309
+ maximum=2048, # Reduced for CPU
310
  step=1,
311
+ value=512, # Reduced for CPU
312
  label="Max new tokens",
313
+ render=False,
314
  ),
315
  gr.Slider(
316
+ minimum=0.0,
317
  maximum=1.0,
318
  step=0.1,
319
+ value=1.0,
320
+ label="top_p",
321
+ render=False,
322
  ),
323
  gr.Slider(
324
  minimum=1,
325
+ maximum=20,
326
  step=1,
327
+ value=20,
328
+ label="top_k",
329
+ render=False,
330
  ),
331
  gr.Slider(
332
+ minimum=0.0,
333
  maximum=2.0,
334
  step=0.1,
335
+ value=1.2,
336
+ label="Repetition penalty",
337
+ render=False,
338
  ),
339
  ],
340
  examples=[
341
+ ["Hello, how are you?"],
342
+ ["Explain quantum computing in simple terms"],
343
+ ["What are the benefits of renewable energy?"],
344
+ ["Write a short poem about technology"],
345
  ],
346
  cache_examples=False,
347
  )
348
 
349
+ with gr.Tab("Vision"):
350
+ with gr.Row():
351
+ input_img = gr.Image(label="Input Picture")
352
+ with gr.Row():
353
+ model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="microsoft/Phi-3.5-vision-instruct")
354
+ with gr.Row():
355
+ text_input = gr.Textbox(label="Question", value="What's in this image?")
356
+ with gr.Row():
357
+ submit_btn = gr.Button(value="Submit")
358
+ with gr.Row():
359
+ output_text = gr.Textbox(label="Output Text")
360
+
361
+ submit_btn.click(stream_vision, [input_img, text_input, model_selector], [output_text])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
 
363
  gr.HTML(footer)
364
 
365
+ # Launch the Gradio app
366
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
367
 
368
  if __name__ == "__main__":
369
+ # Start Flask server in a separate thread
 
 
 
 
 
 
370
  flask_thread = threading.Thread(target=run_flask, daemon=True)
371
  flask_thread.start()
372
 
373
+ # Run Gradio in main thread
374
  run_gradio()