from flask import Flask, request, jsonify
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
import base64
import io
import os
from werkzeug.utils import secure_filename

app = Flask(__name__)

# Model identifier and compute device (CUDA when available, else CPU).
MODEL_ID = "HuggingFaceTB/SmolVLM-500M-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily-initialized globals shared by all requests.
processor = None
model = None


def load_model():
    """Load the processor and model once and cache them in module globals.

    Safe to call repeatedly; subsequent calls are no-ops once both are set.
    NOTE(review): not thread-safe -- fine for the single-process dev server
    started below, but add a lock if served by a threaded WSGI server.
    """
    global processor, model
    if processor is None or model is None:
        print("Loading model and processor...")
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        model = AutoModelForVision2Seq.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            # flash-attention kernels are only available on CUDA builds
            _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
        ).to(DEVICE)
        print("Model loaded successfully!")


def process_image(image_input):
    """Resolve *image_input* to a PIL image.

    Accepts an http(s) URL, a ``data:image/...;base64,`` data URL, or a
    local file path.

    Raises:
        ValueError: when the input cannot be decoded or loaded.
    """
    try:
        if image_input.startswith('http'):
            # Remote image: delegate fetching to transformers' helper.
            return load_image(image_input)
        elif image_input.startswith('data:image'):
            # Inline base64 image: strip the "data:image/...;base64," prefix.
            # maxsplit=1 keeps the payload intact even if it contained a comma.
            image_data = image_input.split(',', 1)[1]
            image_bytes = base64.b64decode(image_data)
            return Image.open(io.BytesIO(image_bytes))
        else:
            # Otherwise treat it as a local file path.
            return load_image(image_input)
    except Exception as e:
        raise ValueError(f"Error processing image: {str(e)}")


@app.route('/', methods=['GET'])
def health_check():
    """Health check endpoint reporting device and model-load status."""
    return jsonify({
        "status": "healthy",
        "device": DEVICE,
        "model_loaded": model is not None,
    })


def _extract_image_and_text(content):
    """Return ``(image, text)`` extracted from an OpenAI-style content array.

    When several image_url or text items are present, the last of each wins
    (matches the original scan order). ``image`` may be None and ``text``
    may be "" when the corresponding item kind is absent.

    Raises:
        ValueError: propagated from process_image on a bad image payload.
    """
    image = None
    text = ""
    for item in content:
        if item.get('type') == 'image_url' and 'image_url' in item and 'url' in item['image_url']:
            image = process_image(item['image_url']['url'])
        elif item.get('type') == 'text':
            text = item.get('text', '')
    return image, text


@app.route('/chat', methods=['POST'])
def chat():
    """Chat endpoint.

    Expects a JSON body with a 'messages' array whose last entry is a user
    message containing one image item and one text item; returns the
    model's generated reply.
    """
    try:
        # Lazy-load so the endpoint works even if startup warm-up was skipped.
        load_model()

        data = request.get_json()
        if not data or 'messages' not in data:
            return jsonify({"error": "Missing 'messages' field in request body"}), 400

        messages = data['messages']
        if not isinstance(messages, list) or not messages:
            return jsonify({"error": "Messages must be a non-empty array"}), 400

        last_message = messages[-1]
        if last_message.get('role') != 'user':
            return jsonify({"error": "Last message must be from user"}), 400

        content = last_message.get('content', [])
        if not isinstance(content, list):
            return jsonify({"error": "Content must be an array"}), 400

        try:
            image, text = _extract_image_and_text(content)
        except ValueError as e:
            # A malformed image payload is a client error, not a server fault.
            return jsonify({"error": f"An error occurred: {str(e)}"}), 400

        if not image:
            return jsonify({"error": "No image found in the message"}), 400
        if not text:
            return jsonify({"error": "No text found in the message"}), 400

        # Rebuild the message in the template shape the processor expects:
        # an image placeholder followed by the user's text.
        model_messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": text},
                ],
            }
        ]
        prompt = processor.apply_chat_template(model_messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)

        # inference_mode avoids building autograd state during generation.
        with torch.inference_mode():
            generated_ids = model.generate(**inputs, max_new_tokens=500)
        generated_texts = processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )

        # The decoded text is the full transcript (prompt + reply); keep only
        # the assistant's part when the marker is present.
        response_text = generated_texts[0]
        if "Assistant:" in response_text:
            response_text = response_text.split("Assistant:")[-1].strip()

        return jsonify({
            "response": response_text,
            "model": "SmolVLM-500M-Instruct",
            "device": DEVICE,
        })

    except Exception as e:
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500


@app.route('/upload', methods=['POST'])
def upload_image():
    """Accept a multipart image upload and return it as a base64 data URL."""
    try:
        if 'image' not in request.files:
            return jsonify({"error": "No image file provided"}), 400

        file = request.files['image']
        if file.filename == '':
            return jsonify({"error": "No file selected"}), 400

        filename = secure_filename(file.filename)

        # Encode directly from the request stream: no temp file on disk, so
        # nothing can leak in /tmp if an error occurs mid-request.
        img_data = base64.b64encode(file.read()).decode('utf-8')

        # Preserve the upload's declared content type instead of always
        # claiming image/jpeg; fall back to the historical default.
        mimetype = file.mimetype or 'image/jpeg'
        return jsonify({
            "image_data": f"data:{mimetype};base64,{img_data}",
            "filename": filename,
        })

    except Exception as e:
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500


if __name__ == '__main__':
    # Warm the model before serving so the first request isn't slow.
    load_model()
    app.run(host='0.0.0.0', port=7860, debug=False)