# (Removed Hugging Face file-viewer scrape artifacts — "f1 / a",
#  "water-water's picture", "Create a", commit 130ef5e — which are not
#  part of the program and would be syntax errors.)
from flask import Flask, render_template, request, jsonify
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
import torch
from PIL import Image
import io
import base64
import uuid
from datetime import datetime
import json
import requests
from urllib3.exceptions import InsecureRequestWarning
from requests.sessions import Session
# urllib3 emits an InsecureRequestWarning for every unverified HTTPS request;
# silence it globally since verification is deliberately disabled below.
requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
# THIS IS THE CORRECTED MONKEY-PATCH BASED ON YOUR PROVIDED SIGNATURE
def patch_requests_ssl():
    """Monkey-patch ``requests`` so SSL verification defaults to off.

    Wraps ``Session.merge_environment_settings`` (signature: url, proxies,
    stream, verify, cert) so that when the caller left ``verify`` unset
    (``None``), it is forced to ``False``.  An explicit ``verify=True`` or a
    CA-bundle path passed by the caller is still honoured.
    """
    unpatched = Session.merge_environment_settings

    def _merged_settings_no_verify(self, url, proxies, stream, verify, cert):
        # Only override when the caller did not choose a value themselves.
        if verify is None:
            verify = False
        # Delegate everything else to the original implementation.
        return unpatched(self, url, proxies, stream, verify, cert)

    Session.merge_environment_settings = _merged_settings_no_verify
# Apply the SSL patch before any HTTP traffic happens.
patch_requests_ssl()

app = Flask(__name__)

# Load model and processor.
# NOTE(review): MODEL_PATH is a local fine-tuned checkpoint directory —
# confirm it exists relative to the working directory before starting.
MODEL_PATH = "qwen2-7b-custom-dataset-finetuned-quanto"
# NOTE(review): DEVICE is computed but never used below — placement is
# handled by device_map="auto"; DTYPE selects fp16 on GPU, fp32 on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH, torch_dtype=DTYPE, device_map="auto", trust_remote_code=True
)
model.eval()  # inference only — disables dropout etc.
processor = Qwen2VLProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

# In-memory storage for chat sessions (use a database in production);
# maps chat_id -> dict with id/title/created_at/updated_at/messages.
chat_sessions = {}
def run_inference(image_bytes, prompt, temperature, top_p, max_tokens):
    """Run one multimodal generation round with the loaded Qwen2-VL model.

    Args:
        image_bytes: Raw encoded image bytes (any format PIL can open).
        prompt: User text prompt; surrounding whitespace is stripped.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The model's decoded response text with special tokens removed.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    messages = [
        {"role": "user", "content": [
            {"type": "text", "text": prompt.strip()},
            {"type": "image", "image": image},
        ]}
    ]
    input_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[input_text], images=[image], return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            use_cache=True,
        )
    # Decode only the newly generated tokens.  The previous implementation
    # decoded the full sequence and split on the raw prompt text, which breaks
    # whenever the chat template rewrites the prompt or the model echoes it.
    generated_ids = outputs[:, inputs["input_ids"].shape[1]:]
    return processor.tokenizer.decode(
        generated_ids[0], skip_special_tokens=True
    ).strip()
@app.route('/')
def index():
    """Serve the single-page chat UI template."""
    return render_template('index.html')
@app.route('/api/chats', methods=['GET'])
def get_chats():
    """Return summaries of all chat sessions, most recently updated first.

    Each summary contains id, title, created_at, updated_at and message_count.
    """
    chats = [
        {
            'id': chat_id,
            'title': chat_data.get('title', 'New Chat'),
            'created_at': chat_data.get('created_at'),
            'updated_at': chat_data.get('updated_at'),
            'message_count': len(chat_data.get('messages', [])),
        }
        for chat_id, chat_data in chat_sessions.items()
    ]
    # Most recent first.  Guard against a missing/None timestamp, which would
    # otherwise raise TypeError when compared against ISO-format strings.
    chats.sort(key=lambda c: c['updated_at'] or '', reverse=True)
    return jsonify(chats)
@app.route('/api/chats', methods=['POST'])
def create_chat():
    """Create and register a fresh, empty chat session.

    Returns the new session's id and its default title as JSON.
    """
    new_id = str(uuid.uuid4())
    timestamp = datetime.utcnow().isoformat()
    session = {
        'id': new_id,
        'title': 'New Chat',
        'created_at': timestamp,
        'updated_at': timestamp,
        'messages': [],
    }
    chat_sessions[new_id] = session
    return jsonify({'id': new_id, 'title': 'New Chat'})
@app.route('/api/chats/<chat_id>', methods=['GET'])
def get_chat(chat_id):
    """Return the full record for one chat session, or 404 if unknown."""
    session = chat_sessions.get(chat_id)
    if session is None:
        return jsonify({'error': 'Chat not found'}), 404
    return jsonify(session)
@app.route('/api/chats/<chat_id>', methods=['DELETE'])
def delete_chat(chat_id):
    """Remove a chat session; responds with 404 when the id is unknown."""
    # pop() both checks membership and removes the entry in one step;
    # sessions are always dicts, so a None result means "not found".
    removed = chat_sessions.pop(chat_id, None)
    if removed is None:
        return jsonify({'error': 'Chat not found'}), 404
    return jsonify({'success': True})
@app.route('/api/chats/<chat_id>/rename', methods=['POST'])
def rename_chat(chat_id):
    """Rename a chat session.

    Expects a JSON body ``{"title": "..."}``.  Returns 404 for an unknown
    chat id and 400 for an empty/missing title or a non-JSON body.
    """
    if chat_id not in chat_sessions:
        return jsonify({'error': 'Chat not found'}), 404
    # request.json raises/aborts when the body is absent or not valid JSON;
    # get_json(silent=True) yields None instead, letting us answer with a
    # clean 400 rather than a 415/500.
    payload = request.get_json(silent=True) or {}
    new_title = (payload.get('title') or '').strip()
    if not new_title:
        return jsonify({'error': 'Title cannot be empty'}), 400
    chat_sessions[chat_id]['title'] = new_title
    chat_sessions[chat_id]['updated_at'] = datetime.utcnow().isoformat()
    return jsonify({'success': True})
@app.route('/infer', methods=['POST'])
def infer():
    """Run model inference on an uploaded image + prompt, logging to a chat.

    Form fields: ``image`` (file upload), ``prompt``, ``temperature``,
    ``top_p``, ``max_tokens``, and an optional ``chat_id`` identifying the
    session the exchange is stored under.

    Returns:
        JSON ``{"response": <model output text>}``.
    """
    file = request.files['image']
    image_bytes = file.read()
    prompt = request.form['prompt']
    temperature = float(request.form['temperature'])
    top_p = float(request.form['top_p'])
    max_tokens = int(request.form['max_tokens'])
    chat_id = request.form.get('chat_id')

    # Get response from model.
    output = run_inference(image_bytes, prompt, temperature, top_p, max_tokens)

    # Base64-encode the image so it can be stored inside the session dict.
    image_base64 = base64.b64encode(image_bytes).decode('utf-8')

    # Record the exchange when the client supplied a known session id.
    if chat_id and chat_id in chat_sessions:
        now = datetime.utcnow().isoformat()
        user_message = {
            'id': str(uuid.uuid4()),
            'role': 'user',
            'content': prompt,
            'image': image_base64,
            'timestamp': now,
        }
        assistant_message = {
            'id': str(uuid.uuid4()),
            'role': 'assistant',
            'content': output,
            'timestamp': now,
        }
        chat_sessions[chat_id]['messages'].extend([user_message, assistant_message])
        chat_sessions[chat_id]['updated_at'] = now
        # First user + assistant pair: derive a title from the prompt.
        if len(chat_sessions[chat_id]['messages']) == 2:
            chat_sessions[chat_id]['title'] = _title_from_prompt(prompt)

    return jsonify({'response': output})


def _title_from_prompt(prompt):
    """Build a short chat title from the prompt's first four words."""
    words = prompt.strip().split()
    title = ' '.join(words[:4])
    # Append an ellipsis only when words were actually dropped.  The old
    # check (len(words[:4]) == 4) also added '...' for a prompt of exactly
    # four words, falsely signalling truncation.
    if len(words) > 4:
        title += '...'
    return title
if __name__ == "__main__":
    # NOTE(review): debug=True enables the Werkzeug interactive debugger
    # (arbitrary code execution) — disable before exposing beyond localhost.
    # use_reloader=False prevents the reloader from loading the model twice.
    app.run(debug=True, use_reloader=False)