|
|
from flask import Flask, request, jsonify
|
|
|
from flask_cors import CORS
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torchvision import transforms
|
|
|
from transformers import DeiTImageProcessor, DeiTForImageClassification
|
|
|
from PIL import Image
|
|
|
import io
|
|
|
import base64
|
|
|
import json
|
|
|
import numpy as np
|
|
|
|
|
|
# Flask application object. flask_cors adds CORS headers to its responses
# using the extension's default settings (no resources/origins are passed here).
app = Flask(import_name=__name__)
CORS(app=app)
|
|
|
|
|
|
class WaterClassificationModel:
    """DeiT image classifier labelling images as hazardous / non_hazardous.

    Wraps a Hugging Face ``DeiTForImageClassification`` checkpoint together
    with its ``DeiTImageProcessor`` and exposes a bytes-in / dict-out
    ``predict`` method for the HTTP layer.
    """

    # Label for each output index of the classifier head, in index order.
    # NOTE(review): assumed to match the checkpoint's training label order;
    # confirm against model.config.id2label.
    class_names = ('hazardous', 'non_hazardous')

    def __init__(self, model_id='durgaprasad143/water-classification-deit', class_names=None):
        """Load the image processor and model weights for *model_id*.

        Args:
            model_id: model identifier passed to both ``from_pretrained`` calls.
            class_names: optional sequence of labels overriding the class-level
                default; element ``i`` names output logit ``i``.

        Raises:
            Exception: any error from loading or moving the model is printed
                and then re-raised unchanged.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if class_names is not None:
            self.class_names = tuple(class_names)

        print(f"Loading model from Hub: {model_id}...")
        try:
            self.processor = DeiTImageProcessor.from_pretrained(model_id)
            self.model = DeiTForImageClassification.from_pretrained(model_id)
            self.model.to(self.device)
            # Inference only: put layers such as dropout into eval behaviour.
            self.model.eval()
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise  # bare raise keeps the original traceback intact
        print(f"Model loaded from {model_id}")

    def preprocess_image(self, image_bytes):
        """Decode encoded image bytes into a model input tensor.

        The image is decoded with PIL and converted to RGB (so greyscale,
        palette and RGBA inputs are accepted), then passed through the DeiT
        image processor.

        Returns:
            The processor's ``pixel_values`` tensor, moved to ``self.device``.
        """
        # convert() returns an independent image, so the decoder can be
        # closed as soon as the conversion is done.
        with Image.open(io.BytesIO(image_bytes)) as img:
            image = img.convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs['pixel_values'].to(self.device)

    def predict(self, image_bytes):
        """Classify an image given as raw encoded bytes.

        Returns:
            On success, ``{'prediction': label, 'confidence': float,
            'probabilities': [float, ...]}``. Here ``probabilities`` is the
            softmax over all classes and ``confidence`` is its largest entry.

            If an exception is raised while decoding or predicting, it is
            printed and ``{'error': message, 'prediction': 'unknown',
            'confidence': 0.0}`` is returned instead.
        """
        try:
            input_tensor = self.preprocess_image(image_bytes)
            with torch.no_grad():
                logits = self.model(input_tensor).logits
                # A single image is processed, so row 0 is its distribution.
                probs = torch.softmax(logits, dim=1)[0]
                predicted_class = torch.argmax(probs).item()
                confidence = probs[predicted_class].item()
            return {
                'prediction': self.class_names[predicted_class],
                'confidence': confidence,
                'probabilities': probs.cpu().tolist(),
            }
        except Exception as e:
            print(f"Prediction error: {e}")
            return {
                'error': str(e),
                'prediction': 'unknown',
                'confidence': 0.0,
            }
|
|
|
|
|
|
|
|
|
# Build the classifier once, at import time, so every request reuses it.
# Construction re-raises load failures, so a failure here aborts startup.
model = WaterClassificationModel(model_id='durgaprasad143/water-classification-deit')
|
|
|
|
|
|
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the service is up.

    The model is constructed at import time and construction errors are
    re-raised, so whenever this handler is reachable the model has loaded.
    """
    return jsonify(status='healthy', model_loaded=True)
|
|
|
|
|
|
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image uploaded as multipart/form-data.

    Expects the file in the ``image`` form field.

    Returns:
        200 with ``model.predict``'s result dict (which carries an ``error``
        key if the model could not classify the image);
        400 if the field is missing or the uploaded file is empty;
        500 with the error text for any other exception.
    """
    try:
        image_file = request.files.get('image')
        if image_file is None:
            return jsonify({'error': 'No image provided'}), 400

        image_bytes = image_file.read()
        if not image_bytes:
            return jsonify({'error': 'Empty image'}), 400

        return jsonify(model.predict(image_bytes))
    except Exception as e:
        print(f"API Error: {e}")
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
|
@app.route('/predict_base64', methods=['POST'])
def predict_base64():
    """Classify a base64-encoded image sent as a JSON body.

    Expects a JSON object ``{"image": "<base64>"}``. A data-URL header such as
    ``data:image/png;base64,`` is stripped before decoding.

    Returns:
        200 with ``model.predict``'s result dict;
        400 if the body is not a JSON object with a string ``image`` field,
        if the payload is not valid base64, or if it decodes to nothing;
        500 with the error text for any other exception.
    """
    try:
        # silent=True: malformed JSON or a non-JSON content type yields None
        # (-> 400 below). Without it, a werkzeug HTTPException would be raised
        # and the broad handler at the bottom would turn it into a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({'error': 'No image provided'}), 400

        image_data = data['image']
        if not isinstance(image_data, str):
            return jsonify({'error': 'Image must be a base64 string'}), 400

        # Strip a data-URL header ("data:<mime>;base64,"). The base64 alphabet
        # has no comma, so everything after the first comma is the payload.
        _, sep, payload = image_data.partition(',')
        if sep:
            image_data = payload

        try:
            image_bytes = base64.b64decode(image_data)
        except ValueError:
            # binascii.Error (bad padding) subclasses ValueError; a plain
            # ValueError is raised for non-ASCII input.
            return jsonify({'error': 'Invalid base64 image data'}), 400

        if not image_bytes:
            return jsonify({'error': 'Empty image'}), 400

        result = model.predict(image_bytes)
        return jsonify(result)
    except Exception as e:
        print(f"API Error: {e}")
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
|
if __name__ == '__main__':
    print("Starting Water Classification API...")
    print("Available endpoints:")
    print(" GET /health")
    print(" POST /predict (multipart form data)")
    print(" POST /predict_base64 (JSON with base64 image)")
    print("Server running on http://localhost:5000")

    # Debug mode is deliberately not forced on. With host='0.0.0.0' it would
    # expose the interactive Werkzeug debugger (arbitrary code execution) to
    # the network. Its reloader would also run this script in a second
    # process, loading the model twice. Flask's app.run() still honours the
    # FLASK_DEBUG environment variable for local development.
    app.run(host='0.0.0.0', port=5000)
|
|
|
|