Joe6636564 committed on
Commit
9e0c9bc
·
verified ·
1 Parent(s): 0d19a07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -98
app.py CHANGED
@@ -32,28 +32,30 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
32
  print("Loading tokenizer and model...")
33
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID1)
34
 
 
 
 
 
35
  # CPU-only model loading
36
  model = AutoModelForCausalLM.from_pretrained(
37
  MODEL_ID1,
38
- torch_dtype=torch.float32, # Use float32 for CPU
39
  device_map="cpu",
40
- low_cpu_mem_usage=True # Optimize for CPU memory
41
  )
42
 
43
- # Vision model setup - FIXED for CPU
44
  print("Loading vision models...")
45
  models = {}
46
  processors = {}
47
 
48
  try:
49
- # Load vision model without flash_attention_2 for CPU
50
  models["microsoft/Phi-3.5-vision-instruct"] = AutoModelForCausalLM.from_pretrained(
51
  "microsoft/Phi-3.5-vision-instruct",
52
  trust_remote_code=True,
53
- torch_dtype=torch.float32, # Use float32 for CPU
54
  device_map="cpu",
55
- low_cpu_mem_usage=True, # Optimize for CPU memory
56
- # Remove flash_attention_2 for CPU compatibility
57
  _attn_implementation=None
58
  ).eval()
59
 
@@ -64,37 +66,21 @@ try:
64
  print("Vision model loaded successfully on CPU")
65
  except Exception as e:
66
  print(f"Error loading vision model: {e}")
67
- # Try alternative loading method
68
- try:
69
- print("Trying alternative loading method for vision model...")
70
- models["microsoft/Phi-3.5-vision-instruct"] = AutoModelForCausalLM.from_pretrained(
71
- "microsoft/Phi-3.5-vision-instruct",
72
- trust_remote_code=True,
73
- torch_dtype=torch.float32,
74
- device_map="cpu"
75
- ).eval()
76
-
77
- processors["microsoft/Phi-3.5-vision-instruct"] = AutoProcessor.from_pretrained(
78
- "microsoft/Phi-3.5-vision-instruct",
79
- trust_remote_code=True
80
- )
81
- print("Vision model loaded successfully with alternative method")
82
- except Exception as e2:
83
- print(f"Failed to load vision model with alternative method: {e2}")
84
 
85
- # Chatbot function
86
  def stream_chat(
87
  message: str,
88
  history: list,
89
  system_prompt: str,
90
- temperature: float = 0.8,
91
  max_new_tokens: int = 1024,
92
- top_p: float = 1.0,
93
- top_k: int = 20,
94
- penalty: float = 1.2,
95
  ):
96
  print(f'message: {message}')
97
  print(f'history: {history}')
 
98
  conversation = [{"role": "system", "content": system_prompt}]
99
 
100
  for prompt, answer in history:
@@ -104,18 +90,35 @@ def stream_chat(
104
  ])
105
 
106
  conversation.append({"role": "user", "content": message})
107
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(device)
108
 
109
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  generate_kwargs = dict(
111
  input_ids=input_ids,
112
  max_new_tokens=max_new_tokens,
113
- do_sample=False if temperature == 0 else True,
114
  top_p=top_p,
115
  top_k=top_k,
116
- temperature=temperature,
117
- eos_token_id=[128001,128008,128009],
 
 
118
  streamer=streamer,
 
 
119
  )
120
 
121
  with torch.no_grad():
@@ -127,7 +130,7 @@ def stream_chat(
127
  buffer += new_text
128
  yield buffer
129
 
130
- # Vision model function
131
  def stream_vision(image, text_input=None, model_id="microsoft/Phi-3.5-vision-instruct"):
132
  if model_id not in models:
133
  return "Vision model not available"
@@ -157,18 +160,23 @@ def stream_vision(image, text_input=None, model_id="microsoft/Phi-3.5-vision-ins
157
  # Process the inputs with the processor
158
  inputs = processor(prompt, images, return_tensors="pt").to(device)
159
 
160
- # Generation parameters - reduced for CPU
161
  generation_args = {
162
- "max_new_tokens": 300, # Further reduced for CPU
163
- "temperature": 0.0,
164
- "do_sample": False,
 
 
 
 
 
 
165
  }
166
 
167
  # Generate the response
168
  try:
169
  generate_ids = model_vision.generate(
170
  **inputs,
171
- eos_token_id=processor.tokenizer.eos_token_id,
172
  **generation_args
173
  )
174
 
@@ -184,7 +192,7 @@ def stream_vision(image, text_input=None, model_id="microsoft/Phi-3.5-vision-ins
184
  except Exception as e:
185
  return f"Error generating vision response: {str(e)}"
186
 
187
- # Flask API Routes
188
  @flask_app.route('/health', methods=['GET'])
189
  def health_check():
190
  vision_loaded = len(models) > 0 and "microsoft/Phi-3.5-vision-instruct" in models
@@ -203,8 +211,10 @@ def api_chat():
203
  data = request.json
204
  message = data.get('message', '')
205
  system_prompt = data.get('system_prompt', 'You are a helpful assistant')
206
- temperature = data.get('temperature', 0.8)
207
- max_new_tokens = data.get('max_new_tokens', 512) # Reduced for CPU
 
 
208
 
209
  # Prepare conversation
210
  conversation = [{"role": "system", "content": system_prompt}]
@@ -214,20 +224,26 @@ def api_chat():
214
  conversation, add_generation_prompt=True, return_tensors="pt"
215
  ).to(device)
216
 
217
- # Generate response
218
  with torch.no_grad():
219
  generate_ids = model.generate(
220
  input_ids,
221
  max_new_tokens=max_new_tokens,
222
  temperature=temperature,
 
 
223
  do_sample=temperature > 0,
224
- eos_token_id=[128001, 128008, 128009]
 
 
 
225
  )
226
 
227
  # Decode response
228
  response = tokenizer.decode(
229
  generate_ids[0][input_ids.shape[1]:],
230
- skip_special_tokens=True
 
231
  )
232
 
233
  return jsonify({
@@ -285,100 +301,84 @@ def run_flask():
285
  flask_app.run(host='0.0.0.0', port=5000, debug=False, threaded=True)
286
 
287
  def run_gradio():
288
- # CSS for the interface
289
  CSS = """.duplicate-button { margin: auto !important; color: white !important; background: black !important; border-radius: 100vh !important;}h3 { text-align: center;}"""
290
  PLACEHOLDER = """<center><p>Hi! I'm your assistant. Feel free to ask your questions</p></center>"""
291
 
292
- # Check if vision model is available
293
  vision_available = len(models) > 0 and "microsoft/Phi-3.5-vision-instruct" in models
294
  vision_status = "Available" if vision_available else "Not Available"
295
 
296
- TITLE = f"<h1><center>Phi-3.5 Chatbot & Phi-3.5 Vision (CPU Version)</center></h1>"
297
  EXPLANATION = f"""<div style="text-align: center; margin-top: 20px;">
298
- <p><strong>CPU-Only Version</strong> - This instance is running on CPU. Responses may be slower than GPU-accelerated versions.</p>
299
  <p><strong>Vision Model Status:</strong> {vision_status}</p>
300
- <p>This app supports both the microsoft/Phi-3.5-mini-instruct model for chat bot and the microsoft/Phi-3.5-vision-instruct model for multimodal model.</p>
301
- <p>Phi-3.5-vision is a lightweight, state-of-the-art open multimodal model built upon datasets which include - synthetic data and filtered publicly available websites - with a focus on very high-quality, reasoning dense data both on text and vision.</p>
302
- <p>Phi-3.5-mini is a lightweight, state-of-the-art open model built upon datasets used for Phi-3 - synthetic data and filtered publicly available websites - with a focus on very high-quality, reasoning dense data.</p>
303
  </div>"""
304
  footer = """<div style="text-align: center; margin-top: 20px;">
305
- <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
306
- <a href="https://github.com/arad1367" target="_blank">GitHub</a> |
307
- <a href="https://huggingface.co/microsoft/Phi-3.5-mini-instruct" target="_blank">microsoft/Phi-3.5-mini-instruct</a> |
308
- <a href="https://huggingface.co/microsoft/Phi-3.5-vision-instruct" target="_blank">microsoft/Phi-3.5-vision-instruct</a>
309
- <br> Made with 💖 by Pejman Ebrahimi | Running on CPU
310
  </div>"""
311
 
312
- # Gradio app with two tabs
313
- with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo: # Changed to default theme
314
  gr.HTML(TITLE)
315
  gr.HTML(EXPLANATION)
316
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
317
 
318
  with gr.Tab("Chatbot"):
319
- chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER, type="messages") # Fixed deprecated type
320
  gr.ChatInterface(
321
  fn=stream_chat,
322
  chatbot=chatbot,
323
  fill_height=True,
324
- additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
325
  additional_inputs=[
326
  gr.Textbox(
327
- value="You are a helpful assistant",
328
  label="System Prompt",
329
- render=False,
330
  ),
331
  gr.Slider(
332
- minimum=0,
333
- maximum=1,
334
  step=0.1,
335
- value=0.8,
336
- label="Temperature",
337
- render=False,
338
  ),
339
  gr.Slider(
340
  minimum=128,
341
- maximum=2048, # Reduced for CPU
342
  step=1,
343
- value=512, # Reduced for CPU
344
  label="Max new tokens",
345
- render=False,
346
  ),
347
  gr.Slider(
348
- minimum=0.0,
349
  maximum=1.0,
350
  step=0.1,
351
- value=1.0,
352
- label="top_p",
353
- render=False,
354
  ),
355
  gr.Slider(
356
  minimum=1,
357
- maximum=20,
358
  step=1,
359
- value=20,
360
- label="top_k",
361
- render=False,
362
  ),
363
  gr.Slider(
364
- minimum=0.0,
365
  maximum=2.0,
366
  step=0.1,
367
- value=1.2,
368
- label="Repetition penalty",
369
- render=False,
370
  ),
371
  ],
372
  examples=[
373
- ["Hello, how are you?"],
374
- ["Explain quantum computing in simple terms"],
375
- ["What are the benefits of renewable energy?"],
376
- ["Write a short poem about technology"],
377
  ],
378
  cache_examples=False,
379
  )
380
 
381
- # Only show vision tab if model is available
382
  if vision_available:
383
  with gr.Tab("Vision"):
384
  with gr.Row():
@@ -388,27 +388,32 @@ def run_gradio():
388
  choices=list(models.keys()),
389
  label="Model",
390
  value="microsoft/Phi-3.5-vision-instruct",
391
- allow_custom_value=False # Fixed warning
392
  )
393
  with gr.Row():
394
- text_input = gr.Textbox(label="Question", value="What's in this image?")
 
 
 
 
395
  with gr.Row():
396
- submit_btn = gr.Button(value="Submit")
397
  with gr.Row():
398
- output_text = gr.Textbox(label="Output Text")
 
 
 
399
 
400
  submit_btn.click(stream_vision, [input_img, text_input, model_selector], [output_text])
401
  else:
402
  with gr.Tab("Vision"):
403
  gr.HTML("""<div style="text-align: center; padding: 40px;">
404
  <h3>Vision Model Not Available</h3>
405
- <p>The vision model failed to load. This is likely due to memory constraints on CPU.</p>
406
- <p>Try using the chat model instead, or run this on a system with more RAM.</p>
407
  </div>""")
408
 
409
  gr.HTML(footer)
410
 
411
- # Launch the Gradio app
412
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
413
 
414
  if __name__ == "__main__":
@@ -419,9 +424,7 @@ if __name__ == "__main__":
419
  print(f"Vision model loaded: {len(models) > 0}")
420
  print("=" * 50)
421
 
422
- # Start Flask server in a separate thread
423
  flask_thread = threading.Thread(target=run_flask, daemon=True)
424
  flask_thread.start()
425
 
426
- # Run Gradio in main thread
427
  run_gradio()
 
32
  print("Loading tokenizer and model...")
33
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID1)
34
 
35
+ # Add padding token if it doesn't exist
36
+ if tokenizer.pad_token is None:
37
+ tokenizer.pad_token = tokenizer.eos_token
38
+
39
  # CPU-only model loading
40
  model = AutoModelForCausalLM.from_pretrained(
41
  MODEL_ID1,
42
+ torch_dtype=torch.float32,
43
  device_map="cpu",
44
+ low_cpu_mem_usage=True
45
  )
46
 
47
+ # Vision model setup
48
  print("Loading vision models...")
49
  models = {}
50
  processors = {}
51
 
52
  try:
 
53
  models["microsoft/Phi-3.5-vision-instruct"] = AutoModelForCausalLM.from_pretrained(
54
  "microsoft/Phi-3.5-vision-instruct",
55
  trust_remote_code=True,
56
+ torch_dtype=torch.float32,
57
  device_map="cpu",
58
+ low_cpu_mem_usage=True,
 
59
  _attn_implementation=None
60
  ).eval()
61
 
 
66
  print("Vision model loaded successfully on CPU")
67
  except Exception as e:
68
  print(f"Error loading vision model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ # Optimized chatbot function with better generation parameters
71
  def stream_chat(
72
  message: str,
73
  history: list,
74
  system_prompt: str,
75
+ temperature: float = 0.7, # Lower temperature for more focused responses
76
  max_new_tokens: int = 1024,
77
+ top_p: float = 0.9, # Lower top_p for less randomness
78
+ top_k: int = 40, # Moderate top_k
79
+ repetition_penalty: float = 1.1, # Lower repetition penalty
80
  ):
81
  print(f'message: {message}')
82
  print(f'history: {history}')
83
+
84
  conversation = [{"role": "system", "content": system_prompt}]
85
 
86
  for prompt, answer in history:
 
90
  ])
91
 
92
  conversation.append({"role": "user", "content": message})
 
93
 
94
+ # Apply chat template
95
+ input_ids = tokenizer.apply_chat_template(
96
+ conversation,
97
+ add_generation_prompt=True,
98
+ return_tensors="pt"
99
+ ).to(device)
100
+
101
+ streamer = TextIteratorStreamer(
102
+ tokenizer,
103
+ timeout=60.0,
104
+ skip_prompt=True,
105
+ skip_special_tokens=True
106
+ )
107
+
108
+ # Optimized generation parameters to reduce repetition
109
  generate_kwargs = dict(
110
  input_ids=input_ids,
111
  max_new_tokens=max_new_tokens,
112
+ temperature=temperature,
113
  top_p=top_p,
114
  top_k=top_k,
115
+ repetition_penalty=repetition_penalty, # Use repetition_penalty instead of penalty
116
+ do_sample=True if temperature > 0 else False,
117
+ pad_token_id=tokenizer.eos_token_id,
118
+ eos_token_id=[tokenizer.eos_token_id, 128001, 128008, 128009],
119
  streamer=streamer,
120
+ no_repeat_ngram_size=3, # Prevent repeating n-grams
121
+ early_stopping=True,
122
  )
123
 
124
  with torch.no_grad():
 
130
  buffer += new_text
131
  yield buffer
132
 
133
+ # Optimized vision model function
134
  def stream_vision(image, text_input=None, model_id="microsoft/Phi-3.5-vision-instruct"):
135
  if model_id not in models:
136
  return "Vision model not available"
 
160
  # Process the inputs with the processor
161
  inputs = processor(prompt, images, return_tensors="pt").to(device)
162
 
163
+ # Optimized generation parameters for vision model
164
  generation_args = {
165
+ "max_new_tokens": 500,
166
+ "temperature": 0.3, # Lower temperature for more factual responses
167
+ "top_p": 0.9,
168
+ "top_k": 30,
169
+ "repetition_penalty": 1.1,
170
+ "do_sample": True,
171
+ "no_repeat_ngram_size": 3,
172
+ "early_stopping": True,
173
+ "eos_token_id": processor.tokenizer.eos_token_id,
174
  }
175
 
176
  # Generate the response
177
  try:
178
  generate_ids = model_vision.generate(
179
  **inputs,
 
180
  **generation_args
181
  )
182
 
 
192
  except Exception as e:
193
  return f"Error generating vision response: {str(e)}"
194
 
195
+ # Flask API Routes with optimized parameters
196
  @flask_app.route('/health', methods=['GET'])
197
  def health_check():
198
  vision_loaded = len(models) > 0 and "microsoft/Phi-3.5-vision-instruct" in models
 
211
  data = request.json
212
  message = data.get('message', '')
213
  system_prompt = data.get('system_prompt', 'You are a helpful assistant')
214
+ temperature = data.get('temperature', 0.7) # Default to lower temperature
215
+ max_new_tokens = data.get('max_new_tokens', 512)
216
+ top_p = data.get('top_p', 0.9)
217
+ repetition_penalty = data.get('repetition_penalty', 1.1)
218
 
219
  # Prepare conversation
220
  conversation = [{"role": "system", "content": system_prompt}]
 
224
  conversation, add_generation_prompt=True, return_tensors="pt"
225
  ).to(device)
226
 
227
+ # Generate response with optimized parameters
228
  with torch.no_grad():
229
  generate_ids = model.generate(
230
  input_ids,
231
  max_new_tokens=max_new_tokens,
232
  temperature=temperature,
233
+ top_p=top_p,
234
+ repetition_penalty=repetition_penalty,
235
  do_sample=temperature > 0,
236
+ no_repeat_ngram_size=3,
237
+ early_stopping=True,
238
+ eos_token_id=[tokenizer.eos_token_id, 128001, 128008, 128009],
239
+ pad_token_id=tokenizer.eos_token_id,
240
  )
241
 
242
  # Decode response
243
  response = tokenizer.decode(
244
  generate_ids[0][input_ids.shape[1]:],
245
+ skip_special_tokens=True,
246
+ clean_up_tokenization_spaces=True
247
  )
248
 
249
  return jsonify({
 
301
  flask_app.run(host='0.0.0.0', port=5000, debug=False, threaded=True)
302
 
303
  def run_gradio():
 
304
  CSS = """.duplicate-button { margin: auto !important; color: white !important; background: black !important; border-radius: 100vh !important;}h3 { text-align: center;}"""
305
  PLACEHOLDER = """<center><p>Hi! I'm your assistant. Feel free to ask your questions</p></center>"""
306
 
 
307
  vision_available = len(models) > 0 and "microsoft/Phi-3.5-vision-instruct" in models
308
  vision_status = "Available" if vision_available else "Not Available"
309
 
310
+ TITLE = f"<h1><center>Phi-3.5 Chatbot & Phi-3.5 Vision (Optimized CPU Version)</center></h1>"
311
  EXPLANATION = f"""<div style="text-align: center; margin-top: 20px;">
312
+ <p><strong>Optimized CPU Version</strong> - Better response quality with reduced repetition</p>
313
  <p><strong>Vision Model Status:</strong> {vision_status}</p>
314
+ <p><strong>Optimizations applied:</strong> Lower temperature, repetition penalty, and no-repeat n-gram size</p>
 
 
315
  </div>"""
316
  footer = """<div style="text-align: center; margin-top: 20px;">
317
+ <br> Made with 💖 by Pejman Ebrahimi | Running on CPU with optimized parameters
 
 
 
 
318
  </div>"""
319
 
320
+ with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
 
321
  gr.HTML(TITLE)
322
  gr.HTML(EXPLANATION)
323
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
324
 
325
  with gr.Tab("Chatbot"):
326
+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER, type="messages")
327
  gr.ChatInterface(
328
  fn=stream_chat,
329
  chatbot=chatbot,
330
  fill_height=True,
331
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Advanced Parameters", open=False),
332
  additional_inputs=[
333
  gr.Textbox(
334
+ value="You are a helpful AI assistant. Provide accurate, concise, and non-repetitive responses.",
335
  label="System Prompt",
 
336
  ),
337
  gr.Slider(
338
+ minimum=0.1,
339
+ maximum=1.0,
340
  step=0.1,
341
+ value=0.7,
342
+ label="Temperature (lower = more focused)",
 
343
  ),
344
  gr.Slider(
345
  minimum=128,
346
+ maximum=2048,
347
  step=1,
348
+ value=512,
349
  label="Max new tokens",
 
350
  ),
351
  gr.Slider(
352
+ minimum=0.5,
353
  maximum=1.0,
354
  step=0.1,
355
+ value=0.9,
356
+ label="Top-p (nucleus sampling)",
 
357
  ),
358
  gr.Slider(
359
  minimum=1,
360
+ maximum=100,
361
  step=1,
362
+ value=40,
363
+ label="Top-k",
 
364
  ),
365
  gr.Slider(
366
+ minimum=1.0,
367
  maximum=2.0,
368
  step=0.1,
369
+ value=1.1,
370
+ label="Repetition Penalty",
 
371
  ),
372
  ],
373
  examples=[
374
+ ["Explain the concept of machine learning in simple terms"],
375
+ ["What are the main differences between Python and JavaScript?"],
376
+ ["How does photosynthesis work in plants?"],
377
+ ["Write a brief summary of the history of the internet"],
378
  ],
379
  cache_examples=False,
380
  )
381
 
 
382
  if vision_available:
383
  with gr.Tab("Vision"):
384
  with gr.Row():
 
388
  choices=list(models.keys()),
389
  label="Model",
390
  value="microsoft/Phi-3.5-vision-instruct",
391
+ allow_custom_value=False
392
  )
393
  with gr.Row():
394
+ text_input = gr.Textbox(
395
+ label="Question",
396
+ value="Describe what you see in this image in detail without repetition.",
397
+ placeholder="Ask a specific question about the image..."
398
+ )
399
  with gr.Row():
400
+ submit_btn = gr.Button(value="Analyze Image")
401
  with gr.Row():
402
+ output_text = gr.Textbox(
403
+ label="Analysis Result",
404
+ lines=5
405
+ )
406
 
407
  submit_btn.click(stream_vision, [input_img, text_input, model_selector], [output_text])
408
  else:
409
  with gr.Tab("Vision"):
410
  gr.HTML("""<div style="text-align: center; padding: 40px;">
411
  <h3>Vision Model Not Available</h3>
412
+ <p>The vision model failed to load due to memory constraints.</p>
 
413
  </div>""")
414
 
415
  gr.HTML(footer)
416
 
 
417
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
418
 
419
  if __name__ == "__main__":
 
424
  print(f"Vision model loaded: {len(models) > 0}")
425
  print("=" * 50)
426
 
 
427
  flask_thread = threading.Thread(target=run_flask, daemon=True)
428
  flask_thread.start()
429
 
 
430
  run_gradio()