Joe6636564 commited on
Commit
0cecd30
·
verified ·
1 Parent(s): 4b77aa7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +381 -42
app.py CHANGED
@@ -1,57 +1,396 @@
1
- from flask import Flask, request, jsonify
2
- from transformers import AutoProcessor, AutoModelForVision2Seq
3
- from PIL import Image
4
  import torch
5
- import io
6
- from flask_cors import CORS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- app = Flask(__name__)
9
- CORS(app)
10
- model_id = "microsoft/Phi-3.5-mini-instruct"
 
11
 
12
- # Load processor + model
13
- processor = AutoProcessor.from_pretrained(model_id)
14
- model = AutoModelForVision2Seq.from_pretrained(
15
- model_id,
16
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
17
- device_map="auto"
18
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- @app.route("/")
21
- def home():
22
- return jsonify({"message": "✅ Phi-3 Vision Flask Endpoint Running"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # Text-only
25
- @app.route("/chat", methods=["POST"])
26
- def chat():
27
- data = request.json
28
- text = data.get("text")
 
 
 
29
 
30
- if not text:
31
- return jsonify({"error": "No text provided"}), 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- inputs = processor(text=text, return_tensors="pt").to(model.device)
34
- output = model.generate(**inputs, max_new_tokens=150)
35
- response = processor.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
36
 
37
- return jsonify({"response": response})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Vision + Text
40
- @app.route("/vision", methods=["POST"])
41
- def vision():
42
- if "image" not in request.files or "text" not in request.form:
43
- return jsonify({"error": "Send `image` (file) and `text` (string)."}), 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- text = request.form["text"]
46
- image_file = request.files["image"]
 
 
 
 
 
47
 
48
- image = Image.open(io.BytesIO(image_file.read())).convert("RGB")
 
49
 
50
- inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
51
- output = model.generate(**inputs, max_new_tokens=150)
52
- response = processor.decode(output[0], skip_special_tokens=True)
 
 
 
 
53
 
54
- return jsonify({"response": response})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  if __name__ == "__main__":
57
- app.run(host="0.0.0.0", port=7860)
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ import time
4
  import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig, AutoProcessor
6
+ import gradio as gr
7
+ from threading import Thread
8
+ from PIL import Image
9
+ import subprocess
10
+ from flask import Flask, request, jsonify
11
+ import threading
12
+
13
+ # Install flash-attn if not already installed
14
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
15
+
16
+ # Initialize Flask app
17
+ flask_app = Flask(__name__)
18
+
19
+ # Device detection
20
+ def get_device():
21
+ if torch.cuda.is_available():
22
+ device = "cuda"
23
+ # Check for CUDA version and capabilities
24
+ cuda_version = torch.version.cuda
25
+ print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
26
+ print(f"CUDA version: {cuda_version}")
27
+ else:
28
+ device = "cpu"
29
+ print("Using CPU")
30
+ return device
31
+
32
+ device = get_device()
33
+
34
+ # Model and tokenizer for the chatbot
35
+ MODEL_ID1 = "microsoft/Phi-3.5-mini-instruct"
36
+ MODEL_LIST1 = ["microsoft/Phi-3.5-mini-instruct"]
37
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
38
+
39
+ # Configure quantization based on device
40
+ if device == "cuda":
41
+ quantization_config = BitsAndBytesConfig(
42
+ load_in_4bit=True,
43
+ bnb_4bit_compute_dtype=torch.bfloat16,
44
+ bnb_4bit_use_double_quant=True,
45
+ bnb_4bit_quant_type="nf4"
46
+ )
47
+ else:
48
+ quantization_config = None
49
+
50
+ print("Loading tokenizer and model...")
51
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID1)
52
+
53
+ if device == "cuda":
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ MODEL_ID1,
56
+ torch_dtype=torch.bfloat16,
57
+ device_map="auto",
58
+ quantization_config=quantization_config
59
+ )
60
+ else:
61
+ model = AutoModelForCausalLM.from_pretrained(
62
+ MODEL_ID1,
63
+ torch_dtype=torch.float32,
64
+ device_map="cpu"
65
+ )
66
 
67
+ # Vision model setup
68
+ print("Loading vision models...")
69
+ models = {}
70
+ processors = {}
71
 
72
+ try:
73
+ if device == "cuda":
74
+ models["microsoft/Phi-3.5-vision-instruct"] = AutoModelForCausalLM.from_pretrained(
75
+ "microsoft/Phi-3.5-vision-instruct",
76
+ trust_remote_code=True,
77
+ torch_dtype="auto",
78
+ _attn_implementation="flash_attention_2",
79
+ device_map="auto"
80
+ ).eval()
81
+ else:
82
+ models["microsoft/Phi-3.5-vision-instruct"] = AutoModelForCausalLM.from_pretrained(
83
+ "microsoft/Phi-3.5-vision-instruct",
84
+ trust_remote_code=True,
85
+ torch_dtype=torch.float32,
86
+ device_map="cpu"
87
+ ).eval()
88
+
89
+ processors["microsoft/Phi-3.5-vision-instruct"] = AutoProcessor.from_pretrained(
90
+ "microsoft/Phi-3.5-vision-instruct",
91
+ trust_remote_code=True
92
+ )
93
+ except Exception as e:
94
+ print(f"Error loading vision model: {e}")
95
 
96
+ # Chatbot tab function
97
+ @spaces.GPU()
98
+ def stream_chat(
99
+ message: str,
100
+ history: list,
101
+ system_prompt: str,
102
+ temperature: float = 0.8,
103
+ max_new_tokens: int = 1024,
104
+ top_p: float = 1.0,
105
+ top_k: int = 20,
106
+ penalty: float = 1.2,
107
+ ):
108
+ print(f'message: {message}')
109
+ print(f'history: {history}')
110
+ conversation = [{"role": "system", "content": system_prompt}]
111
+
112
+ for prompt, answer in history:
113
+ conversation.extend([
114
+ {"role": "user", "content": prompt},
115
+ {"role": "assistant", "content": answer},
116
+ ])
117
+
118
+ conversation.append({"role": "user", "content": message})
119
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
120
+
121
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
122
+ generate_kwargs = dict(
123
+ input_ids=input_ids,
124
+ max_new_tokens=max_new_tokens,
125
+ do_sample=False if temperature == 0 else True,
126
+ top_p=top_p,
127
+ top_k=top_k,
128
+ temperature=temperature,
129
+ eos_token_id=[128001,128008,128009],
130
+ streamer=streamer,
131
+ )
132
 
133
+ with torch.no_grad():
134
+ thread = Thread(target=model.generate, kwargs=generate_kwargs)
135
+ thread.start()
136
+
137
+ buffer = ""
138
+ for new_text in streamer:
139
+ buffer += new_text
140
+ yield buffer
141
 
142
+ # Vision model tab function
143
+ @spaces.GPU()
144
+ def stream_vision(image, text_input=None, model_id="microsoft/Phi-3.5-vision-instruct"):
145
+ if model_id not in models:
146
+ return "Vision model not available"
147
+
148
+ model_vision = models[model_id]
149
+ processor = processors[model_id]
150
+
151
+ # Prepare the image list and corresponding tags
152
+ images = [Image.fromarray(image).convert("RGB")]
153
+ placeholder = "<|image_1|>\n"
154
+
155
+ # Construct the prompt with the image tag and the user's text input
156
+ if text_input:
157
+ prompt_content = placeholder + text_input
158
+ else:
159
+ prompt_content = placeholder
160
+
161
+ messages = [
162
+ {"role": "user", "content": prompt_content},
163
+ ]
164
+
165
+ # Apply the chat template to the messages
166
+ prompt = processor.tokenizer.apply_chat_template(
167
+ messages, tokenize=False, add_generation_prompt=True
168
+ )
169
+
170
+ # Process the inputs with the processor
171
+ inputs = processor(prompt, images, return_tensors="pt").to(device)
172
+
173
+ # Generation parameters
174
+ generation_args = {
175
+ "max_new_tokens": 1000,
176
+ "temperature": 0.0,
177
+ "do_sample": False,
178
+ }
179
+
180
+ # Generate the response
181
+ generate_ids = model_vision.generate(
182
+ **inputs,
183
+ eos_token_id=processor.tokenizer.eos_token_id,
184
+ **generation_args
185
+ )
186
+
187
+ # Remove input tokens from the generated response
188
+ generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
189
+
190
+ # Decode the generated output
191
+ response = processor.batch_decode(
192
+ generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
193
+ )[0]
194
+
195
+ return response
196
 
197
+ # Flask API Routes
198
+ @flask_app.route('/health', methods=['GET'])
199
+ def health_check():
200
+ return jsonify({
201
+ "status": "healthy",
202
+ "device": device,
203
+ "models_loaded": {
204
+ "chatbot": MODEL_ID1 in globals() and 'model' in globals(),
205
+ "vision": len(models) > 0
206
+ }
207
+ })
208
 
209
+ @flask_app.route('/api/chat', methods=['POST'])
210
+ def api_chat():
211
+ try:
212
+ data = request.json
213
+ message = data.get('message', '')
214
+ system_prompt = data.get('system_prompt', 'You are a helpful assistant')
215
+ temperature = data.get('temperature', 0.8)
216
+ max_new_tokens = data.get('max_new_tokens', 1024)
217
+
218
+ # Prepare conversation
219
+ conversation = [{"role": "system", "content": system_prompt}]
220
+ conversation.append({"role": "user", "content": message})
221
+
222
+ input_ids = tokenizer.apply_chat_template(
223
+ conversation, add_generation_prompt=True, return_tensors="pt"
224
+ ).to(model.device)
225
+
226
+ # Generate response
227
+ with torch.no_grad():
228
+ generate_ids = model.generate(
229
+ input_ids,
230
+ max_new_tokens=max_new_tokens,
231
+ temperature=temperature,
232
+ do_sample=temperature > 0,
233
+ eos_token_id=[128001, 128008, 128009]
234
+ )
235
+
236
+ # Decode response
237
+ response = tokenizer.decode(
238
+ generate_ids[0][input_ids.shape[1]:],
239
+ skip_special_tokens=True
240
+ )
241
+
242
+ return jsonify({
243
+ "response": response,
244
+ "device": device,
245
+ "model": MODEL_ID1
246
+ })
247
+
248
+ except Exception as e:
249
+ return jsonify({"error": str(e)}), 500
250
 
251
+ @flask_app.route('/api/vision', methods=['POST'])
252
+ def api_vision():
253
+ try:
254
+ if 'image' not in request.files:
255
+ return jsonify({"error": "No image provided"}), 400
256
+
257
+ image_file = request.files['image']
258
+ text_input = request.form.get('text_input', '')
259
+ model_id = request.form.get('model_id', 'microsoft/Phi-3.5-vision-instruct')
260
+
261
+ if model_id not in models:
262
+ return jsonify({"error": "Vision model not available"}), 400
263
+
264
+ # Process image
265
+ image = Image.open(image_file.stream).convert("RGB")
266
+
267
+ # Use the existing vision function
268
+ response = stream_vision(
269
+ image=np.array(image),
270
+ text_input=text_input,
271
+ model_id=model_id
272
+ )
273
+
274
+ return jsonify({
275
+ "response": response,
276
+ "device": device,
277
+ "model": model_id
278
+ })
279
+
280
+ except Exception as e:
281
+ return jsonify({"error": str(e)}), 500
282
 
283
+ @flask_app.route('/api/models', methods=['GET'])
284
+ def get_models():
285
+ return jsonify({
286
+ "chat_model": MODEL_ID1,
287
+ "vision_models": list(models.keys()),
288
+ "device": device
289
+ })
290
 
291
+ def run_flask():
292
+ flask_app.run(host='0.0.0.0', port=5000, debug=False, threaded=True)
293
 
294
+ def run_gradio():
295
+ # CSS for the interface
296
+ CSS = """.duplicate-button { margin: auto !important; color: white !important; background: black !important; border-radius: 100vh !important;}h3 { text-align: center;}"""
297
+ PLACEHOLDER = """<center><p>Hi! I'm your assistant. Feel free to ask your questions</p></center>"""
298
+ TITLE = "<h1><center>Phi-3.5 Chatbot & Phi-3.5 Vision</center></h1>"
299
+ EXPLANATION = """<div style="text-align: center; margin-top: 20px;"> <p>This app supports both the microsoft/Phi-3.5-mini-instruct model for chat bot and the microsoft/Phi-3.5-vision-instruct model for multimodal model.</p> <p>Phi-3.5-vision is a lightweight, state-of-the-art open multimodal model built upon datasets which include - synthetic data and filtered publicly available websites - with a focus on very high-quality, reasoning dense data both on text and vision. The model belongs to the Phi-3 model family, and the multimodal version comes with 128K context length (in tokens) it can support. The model underwent a rigorous enhancement process, incorporating both supervised fine-tuning and direct preference optimization to ensure precise instruction adherence and robust safety measures.</p> <p>Phi-3.5-mini is a lightweight, state-of-the-art open model built upon datasets used for Phi-3 - synthetic data and filtered publicly available websites - with a focus on very high-quality, reasoning dense data. The model belongs to the Phi-3 model family and supports 128K token context length. The model underwent a rigorous enhancement process, incorporating both supervised fine-tuning, proximal policy optimization, and direct preference optimization to ensure precise instruction adherence and robust safety measures.</p></div>"""
300
+ footer = """<div style="text-align: center; margin-top: 20px;"> <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> | <a href="https://github.com/arad1367" target="_blank">GitHub</a> | <a href="https://arad1367.pythonanywhere.com/" target="_blank">Live demo of my PhD defense</a> | <a href="https://huggingface.co/microsoft/Phi-3.5-mini-instruct" target="_blank">microsoft/Phi-3.5-mini-instruct</a> | <a href="https://huggingface.co/microsoft/Phi-3.5-vision-instruct" target="_blank">microsoft/Phi-3.5-vision-instruct</a> <br> Made with 💖 by Pejman Ebrahimi</div>"""
301
 
302
+ # Gradio app with two tabs
303
+ with gr.Blocks(css=CSS, theme="small_and_pretty") as demo:
304
+ gr.HTML(TITLE)
305
+ gr.HTML(EXPLANATION)
306
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
307
+
308
+ with gr.Tab("Chatbot"):
309
+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
310
+ gr.ChatInterface(
311
+ fn=stream_chat,
312
+ chatbot=chatbot,
313
+ fill_height=True,
314
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
315
+ additional_inputs=[
316
+ gr.Textbox(
317
+ value="You are a helpful assistant",
318
+ label="System Prompt",
319
+ render=False,
320
+ ),
321
+ gr.Slider(
322
+ minimum=0,
323
+ maximum=1,
324
+ step=0.1,
325
+ value=0.8,
326
+ label="Temperature",
327
+ render=False,
328
+ ),
329
+ gr.Slider(
330
+ minimum=128,
331
+ maximum=8192,
332
+ step=1,
333
+ value=1024,
334
+ label="Max new tokens",
335
+ render=False,
336
+ ),
337
+ gr.Slider(
338
+ minimum=0.0,
339
+ maximum=1.0,
340
+ step=0.1,
341
+ value=1.0,
342
+ label="top_p",
343
+ render=False,
344
+ ),
345
+ gr.Slider(
346
+ minimum=1,
347
+ maximum=20,
348
+ step=1,
349
+ value=20,
350
+ label="top_k",
351
+ render=False,
352
+ ),
353
+ gr.Slider(
354
+ minimum=0.0,
355
+ maximum=2.0,
356
+ step=0.1,
357
+ value=1.2,
358
+ label="Repetition penalty",
359
+ render=False,
360
+ ),
361
+ ],
362
+ examples=[
363
+ ["How to make a self-driving car?"],
364
+ ["Give me a creative idea to establish a startup"],
365
+ ["How can I improve my programming skills?"],
366
+ ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
367
+ ],
368
+ cache_examples=False,
369
+ )
370
+
371
+ with gr.Tab("Vision"):
372
+ with gr.Row():
373
+ input_img = gr.Image(label="Input Picture")
374
+ with gr.Row():
375
+ model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="microsoft/Phi-3.5-vision-instruct")
376
+ with gr.Row():
377
+ text_input = gr.Textbox(label="Question")
378
+ with gr.Row():
379
+ submit_btn = gr.Button(value="Submit")
380
+ with gr.Row():
381
+ output_text = gr.Textbox(label="Output Text")
382
+
383
+ submit_btn.click(stream_vision, [input_img, text_input, model_selector], [output_text])
384
+
385
+ gr.HTML(footer)
386
+
387
+ # Launch the combined app
388
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
389
 
390
  if __name__ == "__main__":
391
+ # Start Flask server in a separate thread
392
+ flask_thread = threading.Thread(target=run_flask, daemon=True)
393
+ flask_thread.start()
394
+
395
+ # Run Gradio in main thread
396
+ run_gradio()