# Hugging Face "Spaces: Running" status banner captured by the page scraper —
# not part of the application code.
| from flask import Flask, render_template, request, jsonify | |
| from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration | |
| import torch | |
| from PIL import Image | |
| import io | |
| import base64 | |
| import uuid | |
| from datetime import datetime | |
| import json | |
| import requests | |
| from urllib3.exceptions import InsecureRequestWarning | |
| from requests.sessions import Session | |
# Silence the "insecure request" warnings that unverified HTTPS calls emit.
requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)


def patch_requests_ssl():
    """Monkey-patch ``Session.merge_environment_settings`` so that SSL
    verification defaults to OFF for every request made through ``requests``.

    ``Session.request`` passes ``verify=None`` when the caller did not supply
    a value, so mapping None -> False disables verification by default while
    still honouring an explicit ``verify=True`` (or a CA-bundle path) from the
    caller.

    NOTE(review): this disables certificate checks process-wide — presumably
    for a TLS-intercepting proxy; confirm that is intended.
    """
    original = Session.merge_environment_settings

    # Signature mirrors requests.Session.merge_environment_settings:
    # (self, url, proxies, stream, verify, cert).
    def merge_no_verify(self, url, proxies, stream, verify, cert):
        if verify is None:
            verify = False
        return original(self, url, proxies, stream, verify, cert)

    Session.merge_environment_settings = merge_no_verify


# Apply the patch at import time, before any request is issued.
patch_requests_ssl()
app = Flask(__name__)

# Load model and processor
# Path to the locally fine-tuned, quantized Qwen2-VL checkpoint directory.
MODEL_PATH = "qwen2-7b-custom-dataset-finetuned-quanto"
# NOTE(review): DEVICE is computed but not used below (device_map="auto"
# handles placement); kept in case code outside this view relies on it.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision on GPU, full precision on CPU.
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH, torch_dtype=DTYPE, device_map="auto", trust_remote_code=True
)
# Inference only — disable dropout etc.
model.eval()
processor = Qwen2VLProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
# In-memory storage for chat sessions (use a database in production)
chat_sessions = {}
def run_inference(image_bytes, prompt, temperature, top_p, max_tokens):
    """Run one image+text generation with the Qwen2-VL model.

    Args:
        image_bytes: Raw encoded image bytes (any PIL-readable format).
        prompt: User text prompt; surrounding whitespace is ignored.
        temperature: Sampling temperature for ``generate``.
        top_p: Nucleus-sampling cutoff for ``generate``.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The model's decoded reply text, stripped, without the echoed prompt.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    messages = [
        {"role": "user", "content": [
            {"type": "text", "text": prompt.strip()},
            {"type": "image", "image": image},
        ]}
    ]
    input_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[input_text], images=[image], return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            use_cache=True,
        )
    # BUG FIX: the old code decoded the full sequence (prompt + reply) and
    # recovered the reply with decoded.split(prompt)[-1], which breaks whenever
    # the prompt text reappears in the reply or the chat template rewrites the
    # prompt. Slicing off the prompt token ids isolates the reply reliably.
    generated_ids = outputs[:, inputs["input_ids"].shape[1]:]
    decoded = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return decoded.strip()
def index():
    """Serve the single-page chat UI."""
    page = render_template('index.html')
    return page
def get_chats():
    """Return lightweight summaries of all chat sessions, newest first.

    Each summary carries id, title, timestamps, and the message count; the
    message bodies themselves are omitted (fetch a single chat for those).
    """
    chats = [
        {
            'id': chat_id,
            'title': chat_data.get('title', 'New Chat'),
            'created_at': chat_data.get('created_at'),
            'updated_at': chat_data.get('updated_at'),
            'message_count': len(chat_data.get('messages', [])),
        }
        for chat_id, chat_data in chat_sessions.items()
    ]
    # Sort by updated_at (most recent first). `or ''` guards against a
    # missing/None timestamp, which would make the str comparison raise
    # TypeError and 500 the whole listing.
    chats.sort(key=lambda c: c['updated_at'] or '', reverse=True)
    return jsonify(chats)
def create_chat():
    """Create a new, empty chat session and return its id and title."""
    new_id = str(uuid.uuid4())
    timestamp = datetime.utcnow().isoformat()
    session = {
        'id': new_id,
        'title': 'New Chat',
        'created_at': timestamp,
        'updated_at': timestamp,   # equal to created_at for a fresh chat
        'messages': [],
    }
    chat_sessions[new_id] = session
    return jsonify({'id': new_id, 'title': 'New Chat'})
def get_chat(chat_id):
    """Return the full payload for one chat session, or a 404 error."""
    chat = chat_sessions.get(chat_id)
    if chat is None:
        return jsonify({'error': 'Chat not found'}), 404
    return jsonify(chat)
def delete_chat(chat_id):
    """Remove a chat session; 404 if the id is unknown."""
    removed = chat_sessions.pop(chat_id, None)
    if removed is None:
        return jsonify({'error': 'Chat not found'}), 404
    return jsonify({'success': True})
def rename_chat(chat_id):
    """Rename a chat session.

    Expects a JSON body ``{"title": "..."}``; returns 404 for an unknown
    chat id and 400 for a missing/empty title.
    """
    if chat_id not in chat_sessions:
        return jsonify({'error': 'Chat not found'}), 404
    # BUG FIX: request.json raises (or is None) when the body is not valid
    # JSON, turning a bad request into a 500. get_json(silent=True) lets us
    # fall through to the clean 400 response instead.
    payload = request.get_json(silent=True) or {}
    new_title = (payload.get('title') or '').strip()
    if not new_title:
        return jsonify({'error': 'Title cannot be empty'}), 400
    chat_sessions[chat_id]['title'] = new_title
    chat_sessions[chat_id]['updated_at'] = datetime.utcnow().isoformat()
    return jsonify({'success': True})
def infer():
    """Handle an inference request and record the exchange in its chat.

    Form fields: ``image`` (file), ``prompt``, ``temperature``, ``top_p``,
    ``max_tokens``, and optional ``chat_id``. Returns ``{"response": ...}``.
    """
    # Validate inputs up front so malformed requests get a 400, not a 500
    # from a KeyError/ValueError deep in the handler.
    file = request.files.get('image')
    if file is None:
        return jsonify({'error': 'Missing image file'}), 400
    prompt = request.form.get('prompt', '')
    if not prompt.strip():
        return jsonify({'error': 'Missing prompt'}), 400
    try:
        temperature = float(request.form['temperature'])
        top_p = float(request.form['top_p'])
        max_tokens = int(request.form['max_tokens'])
    except (KeyError, ValueError):
        return jsonify({'error': 'Invalid sampling parameters'}), 400
    chat_id = request.form.get('chat_id')

    image_bytes = file.read()
    # Get response from model
    output = run_inference(image_bytes, prompt, temperature, top_p, max_tokens)
    # Convert image to base64 for storage
    image_base64 = base64.b64encode(image_bytes).decode('utf-8')

    # Store the user/assistant pair in the chat session, if one was given.
    if chat_id and chat_id in chat_sessions:
        now = datetime.utcnow().isoformat()
        user_message = {
            'id': str(uuid.uuid4()),
            'role': 'user',
            'content': prompt,
            'image': image_base64,
            'timestamp': now,
        }
        assistant_message = {
            'id': str(uuid.uuid4()),
            'role': 'assistant',
            'content': output,
            'timestamp': now,
        }
        chat_sessions[chat_id]['messages'].extend([user_message, assistant_message])
        chat_sessions[chat_id]['updated_at'] = now
        # Auto-title the chat from the first prompt (first user + assistant pair).
        if len(chat_sessions[chat_id]['messages']) == 2:
            words = prompt.strip().split()
            title_words = words[:4]
            # BUG FIX: the old code appended '...' whenever 4 words were kept,
            # even for a prompt of exactly 4 words. Append it only when the
            # title actually truncates the prompt.
            suffix = '...' if len(words) > 4 else ''
            chat_sessions[chat_id]['title'] = ' '.join(title_words) + suffix
    return jsonify({'response': output})
if __name__ == "__main__":
    # Reloader disabled so the multi-GB model is not loaded a second time
    # by the reloader's child process.
    app.run(use_reloader=False, debug=True)