import ast

import torch
from flask import Flask, request, jsonify

from handler import EndpointHandler
# Flask application object for this service.
app = Flask(import_name=__name__)

# The model handler is constructed once, when this module is imported, and the
# same instance is reused for every request instead of being rebuilt per call.
handler = EndpointHandler()
# NOTE(review): no route decorator was present, so Flask never dispatched to
# this view. "/predict" (POST) is assumed from the function name -- confirm
# against existing clients.
@app.route('/predict', methods=['POST'])
def predict():
    """Run the model handler on an uploaded image.

    Request fields (multipart form; Flask only fills ``request.files`` for
    multipart/form-data uploads):
      * ``file`` -- the image upload (required).
      * ``point_coords`` / ``point_labels`` -- optional Python-literal
        strings, e.g. ``"[[500, 375]]"`` and ``"[1]"``. They are forwarded to
        the handler only when BOTH are present and non-empty; otherwise the
        raw image bytes alone are passed.

    Returns:
        ``jsonify(result)`` of the handler output on success; otherwise a JSON
        ``{'error': ...}`` body with status 400 (missing/empty upload or
        unparseable point prompts) or 500 (handler or serialization failure).
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    upload = request.files['file']
    if upload.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    image_bytes = upload.read()

    point_coords = request.form.get('point_coords')
    point_labels = request.form.get('point_labels')

    payload = image_bytes
    if point_coords and point_labels:
        # SECURITY: these strings come straight from the client. Passing them
        # to eval() would execute arbitrary Python on the server.
        # ast.literal_eval parses the same literal syntax but evaluates only
        # literals (lists, tuples, numbers, ...), never code.
        try:
            payload = {
                'image': image_bytes,
                'point_coords': ast.literal_eval(point_coords),
                'point_labels': ast.literal_eval(point_labels),
            }
        except (ValueError, TypeError, SyntaxError,
                MemoryError, RecursionError) as e:
            # Malformed client input is a client error, not a server error.
            return jsonify({'error': f'Invalid point prompts: {e}'}), 400

    try:
        # jsonify stays inside the try so a non-serializable handler result
        # still yields the JSON 500 response below.
        return jsonify(handler(payload))
    except Exception as e:
        # Request boundary: keep the traceback in the server log while still
        # returning a JSON error to the caller.
        app.logger.exception('Prediction failed')
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Start the development server on port 5000. Debug mode is intentionally
    # not hard-coded on: it enables the Werkzeug interactive debugger, which
    # allows executing arbitrary code. Flask's app.run() reads the FLASK_DEBUG
    # environment variable itself, so run with FLASK_DEBUG=1 to opt in locally.
    app.run(port=5000)