"""Flask service that classifies uploaded images.

Each image is embedded with the CLIP ViT-B/32 image encoder, L2-normalised,
and passed to a joblib-serialised classifier. The classifier's output is
decoded with a joblib-serialised label encoder (``inverse_transform``).
"""
import functools
import os

import joblib
import numpy as np
import torch
from PIL import Image, UnidentifiedImageError
from transformers import CLIPModel, CLIPProcessor

from config import DEVICE, MODEL_SAVE_PATH
from flask import Flask, jsonify, request
from flask_cors import CORS

CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
LABEL_ENCODER_PATH = "label_encoder.joblib"

app = Flask(__name__)
CORS(app)

# The CLIP backbone is loaded once at import time and shared by all requests.
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)


@functools.lru_cache(maxsize=1)
def _load_classifier():
    """Load (once) and return the ``(classifier, label_encoder)`` pair.

    The result is cached, so the joblib files are read on the first
    prediction instead of on every request. joblib is pickle-based: these
    paths must point at trusted, locally produced artefacts and never at
    user-supplied files.
    """
    model = joblib.load(MODEL_SAVE_PATH)
    label_encoder = joblib.load(LABEL_ENCODER_PATH)
    return model, label_encoder


def predict_image(image_path):
    """Return the predicted label for a single image.

    Args:
        image_path: A filesystem path or a readable binary file object;
            anything accepted by ``PIL.Image.open``.

    Returns:
        The decoded label, converted to a native Python value (via
        ``ndarray.tolist``) so that ``jsonify`` can serialise it.

    Raises:
        PIL.UnidentifiedImageError: If the input cannot be decoded as an image.
        Errors raised by the classifier or label encoder propagate unchanged.
    """
    # The context manager closes the underlying file handle even on error;
    # convert() returns an independent, fully loaded image.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
    # L2-normalise each embedding. Presumably the classifier was trained on
    # normalised CLIP features -- TODO confirm against the training script.
    image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
    features = image_features.cpu().numpy()

    model, label_encoder = _load_classifier()
    pred = model.predict(features)
    labels = label_encoder.inverse_transform(pred)
    # Indexing yields a numpy scalar (e.g. np.int64), which jsonify cannot
    # serialise; tolist() produces native Python values.
    return np.asarray(labels).tolist()[0]


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded in the multipart form field ``image``.

    Returns:
        200 ``{"prediction": <label>}`` on success.
        400 ``{"error": ...}`` if no file was sent, the filename is empty,
        or the upload cannot be decoded as an image.
        500 ``{"error": ...}`` for any other failure during prediction.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image uploaded'}), 400

    image = request.files['image']
    if image.filename == '':
        return jsonify({'error': 'No image selected'}), 400

    try:
        # Decode straight from the upload stream. A shared on-disk temp file
        # ("temp_image.jpg") raced between concurrent requests and was left
        # behind whenever prediction raised.
        prediction = predict_image(image.stream)
    except UnidentifiedImageError as e:
        # The client sent something that is not an image: a client error.
        return jsonify({'error': str(e)}), 400
    except Exception as e:  # top-level request boundary: log, then report
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500

    return jsonify({'prediction': prediction})


@app.route('/healthcheck', methods=['GET'])
def healthcheck():
    """Liveness probe: always returns 200 ``{"status": "ok"}``."""
    return jsonify({'status': 'ok'}), 200


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    # The Werkzeug debugger permits arbitrary code execution, so it must not
    # be on by default for a server bound to all interfaces. Opt in with
    # FLASK_DEBUG=1.
    debug = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(debug=debug, host='0.0.0.0', port=port)