Spaces:
Sleeping
Sleeping
import base64
import io
import mimetypes
import os

import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
from werkzeug.utils import secure_filename
| app = Flask(__name__) | |
| # Global variables for model and processor | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| processor = None | |
| model = None | |
| def load_model(): | |
| """Load the model and processor globally""" | |
| global processor, model | |
| if processor is None or model is None: | |
| print("Loading model and processor...") | |
| processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct") | |
| model = AutoModelForVision2Seq.from_pretrained( | |
| "HuggingFaceTB/SmolVLM-500M-Instruct", | |
| torch_dtype=torch.bfloat16, | |
| _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager", | |
| ).to(DEVICE) | |
| print("Model loaded successfully!") | |
| def process_image(image_input): | |
| """Process image input which can be a URL, base64 string, or file path""" | |
| try: | |
| if image_input.startswith('http'): | |
| # Load image from URL | |
| return load_image(image_input) | |
| elif image_input.startswith('data:image'): | |
| # Handle base64 encoded image | |
| # Remove the data URL prefix | |
| image_data = image_input.split(',')[1] | |
| image_bytes = base64.b64decode(image_data) | |
| image = Image.open(io.BytesIO(image_bytes)) | |
| return image | |
| else: | |
| # Assume it's a file path | |
| return load_image(image_input) | |
| except Exception as e: | |
| raise ValueError(f"Error processing image: {str(e)}") | |
| def health_check(): | |
| """Health check endpoint""" | |
| return jsonify({ | |
| "status": "healthy", | |
| "device": DEVICE, | |
| "model_loaded": model is not None | |
| }) | |
| def chat(): | |
| """Main chat endpoint that accepts messages array""" | |
| try: | |
| # Load model if not already loaded | |
| load_model() | |
| # Get request data | |
| data = request.get_json() | |
| if not data or 'messages' not in data: | |
| return jsonify({ | |
| "error": "Missing 'messages' field in request body" | |
| }), 400 | |
| messages = data['messages'] | |
| if not isinstance(messages, list) or len(messages) == 0: | |
| return jsonify({ | |
| "error": "Messages must be a non-empty array" | |
| }), 400 | |
| # Process the last user message to extract image and text | |
| last_message = messages[-1] | |
| if last_message.get('role') != 'user': | |
| return jsonify({ | |
| "error": "Last message must be from user" | |
| }), 400 | |
| content = last_message.get('content', []) | |
| if not isinstance(content, list): | |
| return jsonify({ | |
| "error": "Content must be an array" | |
| }), 400 | |
| # Extract image and text from content | |
| image = None | |
| text = "" | |
| for item in content: | |
| if item.get('type') == 'image_url' and 'image_url' in item and 'url' in item['image_url']: | |
| image = process_image(item['image_url']['url']) | |
| elif item.get('type') == 'text': | |
| text = item.get('text', '') | |
| if not image: | |
| return jsonify({ | |
| "error": "No image found in the message" | |
| }), 400 | |
| if not text: | |
| return jsonify({ | |
| "error": "No text found in the message" | |
| }), 400 | |
| # Prepare inputs for the model | |
| model_messages = [ | |
| { | |
| "role": "user", | |
| "content": [ | |
| {"type": "image"}, | |
| {"type": "text", "text": text} | |
| ] | |
| } | |
| ] | |
| prompt = processor.apply_chat_template(model_messages, add_generation_prompt=True) | |
| inputs = processor(text=prompt, images=[image], return_tensors="pt") | |
| inputs = inputs.to(DEVICE) | |
| # Generate response | |
| generated_ids = model.generate(**inputs, max_new_tokens=500) | |
| generated_texts = processor.batch_decode( | |
| generated_ids, | |
| skip_special_tokens=True, | |
| ) | |
| # Extract the assistant's response | |
| response_text = generated_texts[0] | |
| # Find the assistant's response in the generated text | |
| if "Assistant:" in response_text: | |
| response_text = response_text.split("Assistant:")[-1].strip() | |
| return jsonify({ | |
| "response": response_text, | |
| "model": "SmolVLM-500M-Instruct", | |
| "device": DEVICE | |
| }) | |
| except Exception as e: | |
| return jsonify({ | |
| "error": f"An error occurred: {str(e)}" | |
| }), 500 | |
| def upload_image(): | |
| """Endpoint to upload an image file""" | |
| try: | |
| if 'image' not in request.files: | |
| return jsonify({ | |
| "error": "No image file provided" | |
| }), 400 | |
| file = request.files['image'] | |
| if file.filename == '': | |
| return jsonify({ | |
| "error": "No file selected" | |
| }), 400 | |
| # Save the uploaded file temporarily | |
| filename = secure_filename(file.filename) | |
| filepath = os.path.join('/tmp', filename) | |
| file.save(filepath) | |
| # Convert to base64 for easy handling | |
| with open(filepath, 'rb') as img_file: | |
| img_data = base64.b64encode(img_file.read()).decode('utf-8') | |
| # Clean up temporary file | |
| os.remove(filepath) | |
| return jsonify({ | |
| "image_data": f"data:image/jpeg;base64,{img_data}", | |
| "filename": filename | |
| }) | |
| except Exception as e: | |
| return jsonify({ | |
| "error": f"An error occurred: {str(e)}" | |
| }), 500 | |
| if __name__ == '__main__': | |
| load_model() | |
| app.run(host='0.0.0.0', port=7860, debug=False) |