from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import DeiTImageProcessor, DeiTForImageClassification
from PIL import Image
import io
import base64
import json
import os
import numpy as np

app = Flask(__name__)
CORS(app)  # Enable CORS for Flutter web support
# NOTE(review): request bodies are unbounded; consider setting
# app.config['MAX_CONTENT_LENGTH'] so huge uploads cannot exhaust memory.


class WaterClassificationModel:
    """DeiT image classifier loaded from the Hugging Face Hub.

    Wraps a ``DeiTImageProcessor`` / ``DeiTForImageClassification`` pair and
    exposes :meth:`predict`, which turns raw image bytes into a label, a
    confidence score, and the per-class softmax probabilities.
    """

    # Output index -> label name. The order must match the label order used
    # when the model was trained.
    # NOTE(review): confirm these agree with the Hub config's id2label.
    CLASS_NAMES = ['hazardous', 'non_hazardous']

    def __init__(self, model_id='durgaprasad143/water-classification-deit'):
        """Load the processor and model for *model_id* onto GPU if available.

        Raises:
            Exception: anything raised by ``from_pretrained``. It is logged
                and re-raised so the service does not start without a model.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🔄 Loading model from Hub: {model_id}...")
        try:
            self.processor = DeiTImageProcessor.from_pretrained(model_id)
            self.model = DeiTForImageClassification.from_pretrained(model_id)
            self.model.to(self.device)
            self.model.eval()  # inference mode: disables dropout etc.
            print(f"✅ Model loaded from {model_id}")
        except Exception as e:
            print(f"❌ Failed to load model: {e}")
            raise  # bare raise keeps the original traceback intact

    def preprocess_image(self, image_bytes):
        """Decode *image_bytes* and return the model-ready ``pixel_values`` tensor.

        The image is converted to RGB and run through the same DeiT processor
        used at training time. The tensor is placed on ``self.device``.
        """
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs['pixel_values'].to(self.device)

    def predict(self, image_bytes):
        """Classify a single encoded image.

        Returns:
            On success: ``{'prediction', 'confidence', 'probabilities'}``.
            On any failure (e.g. undecodable image bytes): ``{'error',
            'prediction': 'unknown', 'confidence': 0.0}``. Errors are
            reported in-band rather than raised.
        """
        try:
            input_tensor = self.preprocess_image(image_bytes)

            with torch.no_grad():
                logits = self.model(input_tensor).logits
                probabilities = torch.nn.functional.softmax(logits, dim=1)
                predicted_class = torch.argmax(probabilities, dim=1).item()
                confidence = probabilities[0][predicted_class].item()

            return {
                'prediction': self.CLASS_NAMES[predicted_class],
                'confidence': confidence,
                'probabilities': probabilities[0].cpu().numpy().tolist(),
            }
        except Exception as e:
            print(f"❌ Prediction error: {e}")
            return {
                'error': str(e),
                'prediction': 'unknown',
                'confidence': 0.0,
            }


# Load the model once at import time; startup fails loudly if this raises.
model = WaterClassificationModel()


@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint (the model is loaded if the process is up)."""
    return jsonify({'status': 'healthy', 'model_loaded': True})


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image uploaded as multipart form field ``image``.

    Returns 400 if the field is missing or empty. Otherwise it returns the
    dict from ``model.predict`` with status 200, including its in-band
    error form.
    """
    try:
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400

        image_bytes = request.files['image'].read()
        if not image_bytes:
            return jsonify({'error': 'Empty image'}), 400

        result = model.predict(image_bytes)
        return jsonify(result)
    except Exception as e:
        print(f"❌ API Error: {e}")
        return jsonify({'error': str(e)}), 500


@app.route('/predict_base64', methods=['POST'])
def predict_base64():
    """Classify a base64-encoded image sent as JSON ``{"image": "<base64>"}``.

    An optional data-URL prefix (everything up to the first comma) is
    stripped. Returns 400 when the body is not a JSON object with an
    ``image`` key, or when the image is not a string, is not valid base64,
    or decodes to nothing.
    """
    try:
        # silent=True: a wrong Content-Type or malformed JSON yields None
        # (-> 400 below) instead of an HTTPException that the generic handler
        # would otherwise have reported as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({'error': 'No image provided'}), 400

        image_data = data['image']
        if not isinstance(image_data, str):
            return jsonify({'error': 'Image must be a base64 string'}), 400
        if ',' in image_data:
            image_data = image_data.split(',', 1)[1]  # Remove data URL prefix

        try:
            image_bytes = base64.b64decode(image_data)
        except ValueError:  # binascii.Error (bad padding) or non-ASCII input
            return jsonify({'error': 'Invalid base64 image'}), 400
        if not image_bytes:
            return jsonify({'error': 'Empty image'}), 400

        result = model.predict(image_bytes)
        return jsonify(result)
    except Exception as e:
        print(f"❌ API Error: {e}")
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    # The Werkzeug debugger can execute arbitrary code and must not be exposed
    # on 0.0.0.0 by default; opt in with FLASK_DEBUG=1 for local development.
    # (Debug mode's reloader also re-runs this module in a child process,
    # which loads the model a second time.)
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes'}

    print("🚀 Starting Water Classification API...")
    print("📡 Available endpoints:")
    print(" GET /health")
    print(" POST /predict (multipart form data)")
    print(" POST /predict_base64 (JSON with base64 image)")
    print("🌐 Server running on http://localhost:5000")
    app.run(host='0.0.0.0', port=5000, debug=debug)