Spaces:
Sleeping
Sleeping
File size: 6,263 Bytes
8fb7337 d30672f 8fb7337 067b592 8fb7337 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import base64
import io
import mimetypes
import os

import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
from werkzeug.utils import secure_filename
app = Flask(__name__)
# Global model state, shared across requests.
# Prefer GPU when available; the model and all inputs are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Lazily populated by load_model() — at startup under __main__, otherwise on
# the first /chat request.
processor = None
model = None
def load_model():
    """Initialise the global ``processor`` and ``model`` pair.

    Idempotent: returns immediately when both globals are already set, so it
    is safe to call on every request.
    """
    global processor, model
    if processor is not None and model is not None:
        return
    print("Loading model and processor...")
    processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
    # flash-attention is only available on CUDA builds; fall back to eager.
    attn_impl = "flash_attention_2" if DEVICE == "cuda" else "eager"
    model = AutoModelForVision2Seq.from_pretrained(
        "HuggingFaceTB/SmolVLM-500M-Instruct",
        torch_dtype=torch.bfloat16,
        _attn_implementation=attn_impl,
    ).to(DEVICE)
    print("Model loaded successfully!")
def process_image(image_input):
    """Resolve *image_input* to a PIL image.

    Accepts three forms:
      * an ``http://`` / ``https://`` URL (fetched by transformers),
      * a ``data:image/...;base64,...`` data URL,
      * anything else is treated as a local file path.

    Raises:
        ValueError: if the input cannot be decoded or loaded.
    """
    try:
        # Match only real URL schemes; a bare startswith('http') would also
        # misclassify local paths like "httpd.png" as remote URLs.
        if image_input.startswith(("http://", "https://")):
            return load_image(image_input)
        if image_input.startswith("data:image"):
            # Strip the "data:image/...;base64," prefix, decode the payload.
            payload = image_input.split(",")[1]
            return Image.open(io.BytesIO(base64.b64decode(payload)))
        # Fall through: assume a local file path.
        return load_image(image_input)
    except Exception as e:
        # Normalise every failure mode to ValueError, preserving the cause.
        raise ValueError(f"Error processing image: {str(e)}") from e
@app.route('/', methods=['GET'])
def health_check():
    """Liveness probe: report device in use and whether the model is loaded."""
    payload = {
        "status": "healthy",
        "device": DEVICE,
        "model_loaded": model is not None,
    }
    return jsonify(payload)
def _extract_image_and_text(content):
    """Scan an OpenAI-style content array for one image and one text part.

    Returns:
        (image_or_None, text): the last image (loaded via process_image) and
        the last text entry found; text defaults to "".

    Raises:
        ValueError: propagated from process_image on undecodable image data.
    """
    image = None
    text = ""
    for item in content:
        if item.get('type') == 'image_url' and 'image_url' in item and 'url' in item['image_url']:
            image = process_image(item['image_url']['url'])
        elif item.get('type') == 'text':
            text = item.get('text', '')
    return image, text


@app.route('/chat', methods=['POST'])
def chat():
    """Chat endpoint: accepts a JSON ``messages`` array whose last entry is a
    user message containing one image part and one text part, and returns the
    model's reply.

    Returns:
        200 with {"response", "model", "device"} on success,
        400 on malformed input (including undecodable image data),
        500 on unexpected server-side failures.
    """
    try:
        # Lazy-load so the endpoint works even if startup warm-up was skipped.
        load_model()

        data = request.get_json()
        if not data or 'messages' not in data:
            return jsonify({"error": "Missing 'messages' field in request body"}), 400

        messages = data['messages']
        if not isinstance(messages, list) or len(messages) == 0:
            return jsonify({"error": "Messages must be a non-empty array"}), 400

        last_message = messages[-1]
        if last_message.get('role') != 'user':
            return jsonify({"error": "Last message must be from user"}), 400

        content = last_message.get('content', [])
        if not isinstance(content, list):
            return jsonify({"error": "Content must be an array"}), 400

        try:
            image, text = _extract_image_and_text(content)
        except ValueError as e:
            # Undecodable image data is the client's fault: 400, not 500.
            return jsonify({"error": str(e)}), 400

        if image is None:
            return jsonify({"error": "No image found in the message"}), 400
        if not text:
            return jsonify({"error": "No text found in the message"}), 400

        # Rebuild the message in the template shape the processor expects:
        # an image placeholder followed by the text prompt.
        model_messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": text},
                ],
            }
        ]
        prompt = processor.apply_chat_template(model_messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)

        generated_ids = model.generate(**inputs, max_new_tokens=500)
        generated_texts = processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )

        # batch_decode returns the prompt + completion; keep only the text
        # after the final "Assistant:" marker when present.
        response_text = generated_texts[0]
        if "Assistant:" in response_text:
            response_text = response_text.split("Assistant:")[-1].strip()

        return jsonify({
            "response": response_text,
            "model": "SmolVLM-500M-Instruct",
            "device": DEVICE,
        })
    except Exception as e:
        # Boundary handler: surface unexpected failures as a JSON 500.
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500
@app.route('/upload', methods=['POST'])
def upload_image():
    """Accept a multipart image upload and return it as a base64 data URL.

    Returns:
        200 with {"image_data": "data:<mime>;base64,...", "filename": ...},
        400 when no file was provided,
        500 on unexpected failures.
    """
    try:
        if 'image' not in request.files:
            return jsonify({
                "error": "No image file provided"
            }), 400
        file = request.files['image']
        if file.filename == '':
            return jsonify({
                "error": "No file selected"
            }), 400
        filename = secure_filename(file.filename)
        # Encode straight from the request stream. The previous version wrote
        # a user-named file into /tmp (collision/race between concurrent
        # uploads, and the temp file leaked if encoding raised) for no gain.
        img_data = base64.b64encode(file.read()).decode('utf-8')
        # Use the upload's real content type instead of hard-coding
        # image/jpeg, which mislabelled PNG/GIF/WebP uploads.
        mime = file.mimetype or mimetypes.guess_type(filename)[0] or 'image/jpeg'
        return jsonify({
            "image_data": f"data:{mime};base64,{img_data}",
            "filename": filename
        })
    except Exception as e:
        return jsonify({
            "error": f"An error occurred: {str(e)}"
        }), 500
if __name__ == '__main__':
    # Warm the model up front so the first request does not pay the load cost.
    load_model()
    app.run(host='0.0.0.0', port=7860, debug=False)