# clip-model / predict.py
# Author: JonSnow1512 — commit d8a1c8d ("Upload 5 files", verified)
import functools
import io
import os

import joblib
import numpy as np
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from config import DEVICE, MODEL_SAVE_PATH
# Flask application exposing the prediction endpoints; CORS is enabled on it
# using flask_cors's default settings.
app = Flask(__name__)
CORS(app)

# Pretrained CLIP checkpoint used to turn uploaded images into embeddings
# that the downstream classifier consumes.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
@functools.lru_cache(maxsize=1)
def _load_classifier():
    """Load the trained classifier and its label encoder once per process.

    The original code re-read both artifacts from disk on every prediction;
    caching them avoids that repeated I/O and unpickling cost per request.

    Returns:
        A ``(model, label_encoder)`` tuple. The model is read from
        ``config.MODEL_SAVE_PATH``; the encoder from ``label_encoder.joblib``
        relative to the current working directory.
    """
    # NOTE(review): joblib.load unpickles arbitrary objects — these paths must
    # only ever point at trusted, locally produced files.
    model = joblib.load(MODEL_SAVE_PATH)
    label_encoder = joblib.load("label_encoder.joblib")
    return model, label_encoder


def _embed_image(image):
    """Return the L2-normalised CLIP image embedding of a PIL image as NumPy."""
    inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
    # Unit-normalise along the last (feature) dimension.
    image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
    return image_features.cpu().numpy()


def predict_image(image_path):
    """Classify one image and return its predicted label.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts — a filesystem path
            or a binary file-like object.

    Returns:
        The predicted class label as a native Python value (via ``tolist()``),
        so it can be passed straight to ``jsonify``.
    """
    # The context manager closes the underlying file deterministically;
    # convert() returns a new, fully loaded image that outlives the handle.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    features = _embed_image(image)
    model, label_encoder = _load_classifier()
    pred = model.predict(features)
    label = label_encoder.inverse_transform(pred)
    # tolist() turns NumPy scalars (e.g. int64 labels) into JSON-safe types.
    return np.asarray(label).tolist()[0]
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image uploaded as multipart form field ``image``.

    Returns:
        200 with ``{"prediction": <label>}`` on success;
        400 with ``{"error": ...}`` when no file or an empty filename is sent;
        500 with ``{"error": <exception message>}`` if prediction fails.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image uploaded'}), 400
    image = request.files['image']
    if image.filename == '':
        return jsonify({'error': 'No image selected'}), 400
    try:
        # Decode the upload from memory instead of a fixed on-disk file: a
        # shared "temp_image.jpg" races between concurrent requests and was
        # left behind whenever prediction raised before the cleanup ran.
        prediction = predict_image(io.BytesIO(image.read()))
        return jsonify({'prediction': prediction})
    except Exception as e:  # request boundary: log and report as a 500
        app.logger.exception("Prediction failed")
        # NOTE(review): str(e) may leak internal details to clients; kept for
        # backward compatibility of the error payload.
        return jsonify({'error': str(e)}), 500
@app.route('/healthcheck', methods=['GET'])
def healthcheck():
    """Report that the service is up: always HTTP 200 with ``{"status": "ok"}``."""
    body = jsonify(status='ok')
    return body, 200
if __name__ == '__main__':
    # Listening port comes from $PORT, defaulting to 5000.
    port = int(os.environ.get('PORT', 5000))
    # Debug mode enables Werkzeug's interactive debugger, which permits
    # arbitrary code execution; combined with host 0.0.0.0 that would expose
    # it to the whole network. Make it opt-in (FLASK_DEBUG=1) for local use.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug, host='0.0.0.0', port=port)