# image / app.py — Hugging Face Space application file
# (uploaded by aliroohan179, commit 067b592, "Update app.py")
from flask import Flask, request, jsonify
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
import base64
import io
import os
from werkzeug.utils import secure_filename
app = Flask(__name__)
# Global variables for model and processor
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
processor = None  # populated lazily by load_model()
model = None  # populated lazily by load_model()
def load_model():
    """Lazily initialize the global processor and model (idempotent).

    Safe to call on every request: once both globals are populated the
    function returns immediately without touching the network or disk.
    """
    global processor, model
    if processor is not None and model is not None:
        return
    print("Loading model and processor...")
    processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
    # flash-attention only works on CUDA; fall back to eager attention on CPU
    attn_impl = "flash_attention_2" if DEVICE == "cuda" else "eager"
    model = AutoModelForVision2Seq.from_pretrained(
        "HuggingFaceTB/SmolVLM-500M-Instruct",
        torch_dtype=torch.bfloat16,
        _attn_implementation=attn_impl,
    ).to(DEVICE)
    print("Model loaded successfully!")
def process_image(image_input):
    """Resolve an image reference to a PIL image.

    Accepts three input forms:
      * a base64 data URL (``data:image/...;base64,...``),
      * an http(s) URL,
      * a local file path.

    Returns:
        PIL.Image.Image: the decoded image.

    Raises:
        ValueError: wrapping (and chaining) whatever went wrong while
            fetching or decoding the image.
    """
    try:
        if image_input.startswith('data:image'):
            # Data URL: strip the "data:image/...;base64," prefix; maxsplit=1
            # keeps the payload intact even if it ever contained a comma.
            image_data = image_input.split(',', 1)[1]
            image_bytes = base64.b64decode(image_data)
            return Image.open(io.BytesIO(image_bytes))
        # load_image handles both http(s) URLs and local file paths, so the
        # two remaining cases collapse into a single call.
        return load_image(image_input)
    except Exception as e:
        # Chain the original exception so the root cause stays debuggable.
        raise ValueError(f"Error processing image: {str(e)}") from e
@app.route('/', methods=['GET'])
def health_check():
    """Liveness probe: report compute device and whether the model is loaded."""
    payload = {
        "status": "healthy",
        "device": DEVICE,
        "model_loaded": model is not None,
    }
    return jsonify(payload)
def _extract_image_and_text(content):
    """Pull the image and prompt text out of an OpenAI-style content array.

    Returns:
        tuple: (PIL image or None, text string — empty if absent).
    """
    image = None
    text = ""
    for item in content:
        if item.get('type') == 'image_url' and 'url' in (item.get('image_url') or {}):
            image = process_image(item['image_url']['url'])
        elif item.get('type') == 'text':
            text = item.get('text', '')
    return image, text


def _generate_response(image, text):
    """Run the VLM on one (image, text) pair and return the assistant's reply."""
    model_messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": text}
            ]
        }
    ]
    prompt = processor.apply_chat_template(model_messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)
    # Inference only — disable gradient bookkeeping to save memory.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=500)
    decoded = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )[0]
    # The decode echoes the prompt; keep only the assistant's turn.
    if "Assistant:" in decoded:
        decoded = decoded.split("Assistant:")[-1].strip()
    return decoded


@app.route('/chat', methods=['POST'])
def chat():
    """Main chat endpoint that accepts an OpenAI-style messages array.

    Expects JSON ``{"messages": [...]}`` where the last message is from the
    user and its content array carries one ``image_url`` item and one
    ``text`` item. Returns the model's reply, or a 400 with a specific
    validation error, or a 500 on any unexpected failure.
    """
    try:
        # Load model if not already loaded (no-op after first call).
        load_model()

        data = request.get_json()
        if not data or 'messages' not in data:
            return jsonify({
                "error": "Missing 'messages' field in request body"
            }), 400
        messages = data['messages']
        if not isinstance(messages, list) or len(messages) == 0:
            return jsonify({
                "error": "Messages must be a non-empty array"
            }), 400

        # Only the last user message is processed for image and text.
        last_message = messages[-1]
        if last_message.get('role') != 'user':
            return jsonify({
                "error": "Last message must be from user"
            }), 400
        content = last_message.get('content', [])
        if not isinstance(content, list):
            return jsonify({
                "error": "Content must be an array"
            }), 400

        image, text = _extract_image_and_text(content)
        if not image:
            return jsonify({
                "error": "No image found in the message"
            }), 400
        if not text:
            return jsonify({
                "error": "No text found in the message"
            }), 400

        response_text = _generate_response(image, text)
        return jsonify({
            "response": response_text,
            "model": "SmolVLM-500M-Instruct",
            "device": DEVICE
        })
    except Exception as e:
        # Top-level boundary: report the failure instead of a bare 500 page.
        return jsonify({
            "error": f"An error occurred: {str(e)}"
        }), 500
@app.route('/upload', methods=['POST'])
def upload_image():
    """Accept a multipart image upload and echo it back as a base64 data URL.

    Fixes over the previous version:
      * no temp file on disk — the original leaked the /tmp file if an
        exception fired between save and remove; reading the stream
        in-memory removes the whole save/read/delete round-trip.
      * the data URL now carries the upload's actual MIME type instead of
        always claiming ``image/jpeg`` (wrong for PNG/GIF uploads).
    """
    try:
        if 'image' not in request.files:
            return jsonify({
                "error": "No image file provided"
            }), 400
        file = request.files['image']
        if file.filename == '':
            return jsonify({
                "error": "No file selected"
            }), 400

        filename = secure_filename(file.filename)
        # Encode straight from the in-memory upload stream.
        img_data = base64.b64encode(file.read()).decode('utf-8')
        # Prefer the client-reported MIME type; keep the old JPEG default
        # as a backward-compatible fallback.
        mime = file.mimetype or 'image/jpeg'
        return jsonify({
            "image_data": f"data:{mime};base64,{img_data}",
            "filename": filename
        })
    except Exception as e:
        return jsonify({
            "error": f"An error occurred: {str(e)}"
        }), 500
if __name__ == '__main__':
    # Warm the model before serving so the first request doesn't pay the
    # multi-second load latency.
    load_model()
    # Port 7860 is the Hugging Face Spaces default; PORT env var lets the
    # same script run elsewhere without edits.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)), debug=False)