water-water committed on
Commit
130ef5e
·
verified ·
1 Parent(s): abc9f86
Files changed (1) hide show
  1. a +193 -0
a ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, jsonify
2
+ from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
3
+ import torch
4
+ from PIL import Image
5
+ import io
6
+ import base64
7
+ import uuid
8
+ from datetime import datetime
9
+ import json
10
+
11
+ import requests
12
+ from urllib3.exceptions import InsecureRequestWarning
13
+ from requests.sessions import Session
14
+
15
# Silence the per-request warning that urllib3 emits for unverified HTTPS.
requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)


def patch_requests_ssl():
    """Monkey-patch requests so certificate verification defaults to off.

    ``Session.merge_environment_settings`` is where requests resolves the
    effective ``verify`` flag for every outgoing request. Wrapping it lets
    an unset (``None``) flag be turned into ``False`` while an explicit
    value passed by the caller is still honoured.
    """
    wrapped = Session.merge_environment_settings

    # Signature mirrors the original exactly: url, proxies, stream, verify, cert.
    def _merge_without_verify(self, url, proxies, stream, verify, cert):
        # Only an unset (None) flag is forced to False; explicit values pass through.
        if verify is None:
            verify = False
        return wrapped(self, url, proxies, stream, verify, cert)

    Session.merge_environment_settings = _merge_without_verify


# Apply the patch early, before any HTTP traffic (e.g. model downloads) occurs.
patch_requests_ssl()
32
+
33
app = Flask(__name__)

# ---- Model / processor initialisation (runs once, at import time) ----

# Path (or hub id) of the fine-tuned, quantised Qwen2-VL checkpoint.
MODEL_PATH = "qwen2-7b-custom-dataset-finetuned-quanto"
# NOTE(review): DEVICE is never referenced below — placement is handled by
# device_map="auto" instead. Confirm it is unused elsewhere before removing.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH, torch_dtype=DTYPE, device_map="auto", trust_remote_code=True
)
model.eval()  # inference only: disables dropout and other training behaviour
processor = Qwen2VLProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

# In-memory storage for chat sessions (use a database in production)
chat_sessions = {}
48
+
49
def run_inference(image_bytes, prompt, temperature, top_p, max_tokens):
    """Run one image+text generation round with the global Qwen2-VL model.

    Args:
        image_bytes: Raw encoded image data (any format PIL can decode).
        prompt: The user's text prompt.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        max_tokens: Maximum number of NEW tokens to generate.

    Returns:
        The model's reply as a stripped string (generated tokens only).
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    messages = [
        {"role": "user", "content": [
            {"type": "text", "text": prompt.strip()},
            {"type": "image", "image": image}
        ]}
    ]

    input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[input_text], images=[image], return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            use_cache=True
        )

    # BUG FIX: the previous version decoded the *full* sequence and split on
    # the prompt text, which breaks whenever the reply itself contains the
    # prompt (or the chat template rewrites it). Slice off the input tokens
    # and decode only what the model actually generated.
    generated_tokens = outputs[0][inputs.input_ids.shape[1]:]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
74
+
75
@app.route('/')
def index():
    """Serve the single-page chat UI."""
    page = render_template('index.html')
    return page
78
+
79
@app.route('/api/chats', methods=['GET'])
def get_chats():
    """Return a summary of every chat session, most recently updated first."""
    summaries = [
        {
            'id': cid,
            'title': data.get('title', 'New Chat'),
            'created_at': data.get('created_at'),
            'updated_at': data.get('updated_at'),
            'message_count': len(data.get('messages', [])),
        }
        for cid, data in chat_sessions.items()
    ]
    # Most recently touched sessions come first.
    summaries.sort(key=lambda c: c['updated_at'], reverse=True)
    return jsonify(summaries)
95
+
96
@app.route('/api/chats', methods=['POST'])
def create_chat():
    """Create an empty chat session and return its id and default title."""
    now = datetime.utcnow().isoformat()
    new_id = str(uuid.uuid4())

    chat_sessions[new_id] = {
        'id': new_id,
        'title': 'New Chat',
        'created_at': now,
        'updated_at': now,
        'messages': [],
    }

    return jsonify({'id': new_id, 'title': 'New Chat'})
111
+
112
@app.route('/api/chats/<chat_id>', methods=['GET'])
def get_chat(chat_id):
    """Return the full state of one chat session, or 404 if the id is unknown."""
    chat = chat_sessions.get(chat_id)
    if chat is None:
        return jsonify({'error': 'Chat not found'}), 404
    return jsonify(chat)
119
+
120
@app.route('/api/chats/<chat_id>', methods=['DELETE'])
def delete_chat(chat_id):
    """Remove a chat session; 404 if the id is unknown."""
    removed = chat_sessions.pop(chat_id, None)
    if removed is None:
        return jsonify({'error': 'Chat not found'}), 404
    return jsonify({'success': True})
128
+
129
@app.route('/api/chats/<chat_id>/rename', methods=['POST'])
def rename_chat(chat_id):
    """Set a new non-empty title on an existing chat session."""
    if chat_id not in chat_sessions:
        return jsonify({'error': 'Chat not found'}), 404

    title = request.json.get('title', '').strip()
    if not title:
        return jsonify({'error': 'Title cannot be empty'}), 400

    chat = chat_sessions[chat_id]
    chat['title'] = title
    chat['updated_at'] = datetime.utcnow().isoformat()

    return jsonify({'success': True})
143
+
144
@app.route('/infer', methods=['POST'])
def infer():
    """Run the model on an uploaded image + prompt and record the exchange.

    Expected multipart form fields: ``image`` (file), ``prompt``,
    ``temperature``, ``top_p``, ``max_tokens``, and optionally ``chat_id``
    to append the exchange to a stored session.

    Returns:
        JSON ``{'response': <model output>}`` on success, or a 400 JSON
        error when a required field is missing or malformed.
    """
    file = request.files.get('image')
    if file is None:
        # Previously a missing field produced an unhandled KeyError; return
        # an explicit 400 with a JSON body instead.
        return jsonify({'error': 'Missing image file'}), 400
    image_bytes = file.read()

    prompt = request.form.get('prompt', '')
    try:
        temperature = float(request.form['temperature'])
        top_p = float(request.form['top_p'])
        max_tokens = int(request.form['max_tokens'])
    except (KeyError, ValueError):
        return jsonify({'error': 'Invalid or missing sampling parameters'}), 400
    chat_id = request.form.get('chat_id')

    # Get response from model
    output = run_inference(image_bytes, prompt, temperature, top_p, max_tokens)

    # Store message in chat session
    if chat_id and chat_id in chat_sessions:
        now = datetime.utcnow().isoformat()
        # Image kept as base64 so the stored chat can be replayed in the UI.
        # (Only encoded when a session actually stores it.)
        image_base64 = base64.b64encode(image_bytes).decode('utf-8')

        user_message = {
            'id': str(uuid.uuid4()),
            'role': 'user',
            'content': prompt,
            'image': image_base64,
            'timestamp': now
        }

        assistant_message = {
            'id': str(uuid.uuid4()),
            'role': 'assistant',
            'content': output,
            'timestamp': now
        }

        chat_sessions[chat_id]['messages'].extend([user_message, assistant_message])
        chat_sessions[chat_id]['updated_at'] = now

        # Auto-title the chat from the first exchange (user + assistant pair).
        if len(chat_sessions[chat_id]['messages']) == 2:
            words = prompt.strip().split()
            title = ' '.join(words[:4])
            # BUG FIX: only append an ellipsis when the prompt was actually
            # truncated — the old check (len(words) == 4) also fired for
            # prompts of exactly four words.
            if len(words) > 4:
                title += '...'
            chat_sessions[chat_id]['title'] = title

    return jsonify({'response': output})
191
+
192
if __name__ == "__main__":
    # NOTE(review): debug=True exposes the Werkzeug interactive debugger —
    # never run this configuration in production. use_reloader=False avoids
    # importing the module (and loading the model) twice.
    app.run(debug=True, use_reloader=False)