# water-classification-api / flask_api.py
# Uploaded by durgaprasad143 to the Hugging Face Hub with huggingface_hub
# (commit 757208b, verified).
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import DeiTImageProcessor, DeiTForImageClassification
from PIL import Image
import io
import base64
import json
import numpy as np
# Flask application that hosts the HTTP endpoints registered below.
app = Flask(__name__)
CORS(app) # Enable CORS for Flutter web support (flask-cors defaults; no origin restriction is configured here)
class WaterClassificationModel:
    """Water image classifier backed by a Hugging Face DeiT checkpoint.

    Wraps a ``DeiTForImageClassification`` model together with its matching
    ``DeiTImageProcessor`` and exposes :meth:`predict`, which takes the raw
    bytes of an encoded image and returns a JSON-serialisable result dict.
    """

    # Label for each logit index. Assumed to match the label order used when
    # the checkpoint was trained; override per instance via ``class_names``.
    class_names = ('hazardous', 'non_hazardous')

    def __init__(self, model_id='durgaprasad143/water-classification-deit',
                 class_names=None):
        """Load the processor and model weights and prepare them for inference.

        Args:
            model_id: Hub repository id (or local path) passed to
                ``from_pretrained`` for both the processor and the model.
            class_names: Optional sequence of labels, one per logit index.
                When omitted the class-level ``class_names`` default is used.

        Raises:
            Exception: whatever ``from_pretrained`` / ``.to`` raises. The error
                is printed and re-raised so that a failed load aborts startup.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if class_names is not None:
            self.class_names = tuple(class_names)
        print(f"πŸ”„ Loading model from Hub: {model_id}...")
        try:
            self.processor = DeiTImageProcessor.from_pretrained(model_id)
            self.model = DeiTForImageClassification.from_pretrained(model_id)
            self.model.to(self.device)
            # Inference only: disable dropout and other training-time behavior.
            self.model.eval()
            print(f"βœ… Model loaded from {model_id}")
        except Exception as e:
            print(f"❌ Failed to load model: {e}")
            # Bare raise keeps the original traceback without an extra frame.
            raise

    def preprocess_image(self, image_bytes):
        """Decode encoded image bytes into the model's input tensor.

        Args:
            image_bytes: Raw bytes of an image file that PIL can open.

        Returns:
            torch.Tensor: the processor's batched ``pixel_values`` tensor,
            moved to ``self.device``.

        Raises:
            PIL.UnidentifiedImageError / OSError: if the bytes are not a
                decodable image.
        """
        # Close the decoder handle once an independent RGB copy exists;
        # ``convert`` fully loads the pixels, so ``image`` stays valid.
        with Image.open(io.BytesIO(image_bytes)) as raw_image:
            image = raw_image.convert('RGB')
        # Use the DeiT processor (same preprocessing as training).
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs['pixel_values'].to(self.device)

    def predict(self, image_bytes):
        """Classify a single encoded image.

        Args:
            image_bytes: Raw bytes of an encoded image file.

        Returns:
            dict: on success ``{'prediction': label, 'confidence': float,
            'probabilities': [float, ...]}``, where ``probabilities`` holds the
            softmax score of every class index for the first (only) item in
            the batch. On any failure the error is printed and
            ``{'error': message, 'prediction': 'unknown', 'confidence': 0.0}``
            is returned instead of raising.
        """
        try:
            input_tensor = self.preprocess_image(image_bytes)
            with torch.no_grad():
                logits = self.model(input_tensor).logits
                probabilities = torch.nn.functional.softmax(logits, dim=1)
                predicted_class = torch.argmax(probabilities, dim=1).item()
                confidence = probabilities[0][predicted_class].item()
            return {
                'prediction': self.class_names[predicted_class],
                'confidence': confidence,
                'probabilities': probabilities[0].cpu().tolist(),
            }
        except Exception as e:
            # Deliberate best-effort contract: callers receive an error dict
            # rather than an exception.
            print(f"❌ Prediction error: {e}")
            return {
                'error': str(e),
                'prediction': 'unknown',
                'confidence': 0.0
            }
# Initialize model
# Created once at import time and reused by the prediction endpoints below.
# WaterClassificationModel.__init__ re-raises load errors, so importing this
# module (and therefore starting the server) fails if the weights cannot be
# loaded.
model = WaterClassificationModel()
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe.

    Always reports healthy with ``model_loaded`` set to True: the model is
    constructed at import time, so this handler only runs after that
    construction has succeeded.
    """
    status_payload = dict(status='healthy', model_loaded=True)
    return jsonify(status_payload)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image uploaded as multipart form data.

    The file must be sent in the ``image`` form field.

    Returns:
        The dict from ``model.predict`` as JSON with status 200. That dict
        may itself carry an ``error`` key if inference failed. Returns 400
        when the field is missing or the upload is empty, and 500 with the
        exception message for anything unexpected.
    """
    try:
        upload = request.files.get('image')
        if upload is None:
            return jsonify({'error': 'No image provided'}), 400

        payload = upload.read()
        if not payload:
            return jsonify({'error': 'Empty image'}), 400

        return jsonify(model.predict(payload))
    except Exception as e:
        print(f"❌ API Error: {e}")
        return jsonify({'error': str(e)}), 500
@app.route('/predict_base64', methods=['POST'])
def predict_base64():
    """Classify a base64-encoded image sent as a JSON body.

    Expects ``{"image": "<base64>"}``. A data-URL prefix such as
    ``data:image/png;base64,`` is stripped before decoding.

    Returns:
        The dict from ``model.predict`` as JSON with status 200. That dict
        may itself carry an ``error`` key if inference failed. Returns 400
        when the body is not a JSON object with an ``image`` string, when the
        string is not valid base64, or when it decodes to nothing. Returns 500
        with the exception message for anything unexpected.
    """
    try:
        # silent=True: a wrong Content-Type or malformed JSON yields None
        # (answered with 400 below). Without it, Flask raises an HTTPException
        # that the generic handler would report as a misleading 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({'error': 'No image provided'}), 400

        image_data = data['image']
        if not isinstance(image_data, str):
            return jsonify({'error': 'Image must be a base64-encoded string'}), 400
        if ',' in image_data:
            image_data = image_data.split(',')[1]  # Remove data URL prefix

        try:
            image_bytes = base64.b64decode(image_data)
        except ValueError:
            # binascii.Error (bad padding) subclasses ValueError, as does the
            # error raised for non-ASCII characters in the input string.
            return jsonify({'error': 'Invalid base64 image data'}), 400
        if not image_bytes:
            return jsonify({'error': 'Empty image'}), 400

        result = model.predict(image_bytes)
        return jsonify(result)
    except Exception as e:
        print(f"❌ API Error: {e}")
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    import os  # local import: only the dev-server entry point reads the environment

    # The Werkzeug interactive debugger must never be on by default. With
    # host='0.0.0.0' it is reachable from the network and allows arbitrary
    # code execution. Debug mode also enables the reloader, which re-runs this
    # module in a child process and so loads the model twice. Opt in
    # explicitly with FLASK_DEBUG=1 for local development.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    # Port is overridable for hosted environments; 5000 preserves the old default.
    port = int(os.environ.get('PORT', '5000'))

    print("πŸš€ Starting Water Classification API...")
    print("πŸ“‘ Available endpoints:")
    print(" GET /health")
    print(" POST /predict (multipart form data)")
    print(" POST /predict_base64 (JSON with base64 image)")
    print(f"🌐 Server running on http://localhost:{port}")
    app.run(host='0.0.0.0', port=port, debug=debug)