Spaces:
Runtime error
Runtime error
import base64
import functools
import io
import os
import warnings

from flask import Flask, request, jsonify
from flask_cors import CORS
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
from werkzeug.utils import secure_filename
# Silences every Python warning process-wide (TensorFlow/Keras notices included).
warnings.filterwarnings('ignore')

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Run TensorFlow on CPU only. CUDA reads CUDA_VISIBLE_DEVICES when it is first
# initialised, so the variable must be set before the first TensorFlow device
# call below (previously it was set afterwards and had no effect);
# set_visible_devices then also hides any GPU TensorFlow has already seen.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf.config.set_visible_devices([], 'GPU')

# Class labels indexed by the model's output position (argmax index -> label).
face_shape_labels = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']

# Keras model shared by all requests; populated by load_face_shape_model(),
# either at startup or lazily on the first prediction.
face_detection_model = None

# Path to the trained Keras model. Set FACE_SHAPE_MODEL_PATH to point
# elsewhere instead of editing this file; relative paths are resolved against
# the current working directory.
model_path = os.environ.get('FACE_SHAPE_MODEL_PATH', './Try_Face_Detection_AI_1.keras')
| ############################################################## | |
| # FACE DETECTION AND PROCESSING FUNCTIONS | |
| ############################################################## | |
| def detect_face_with_opencv(image): | |
| """Detect face using OpenCV's Haar Cascade""" | |
| if image is None: | |
| return None | |
| # Convert to numpy array if needed | |
| if not isinstance(image, np.ndarray): | |
| if hasattr(image, 'convert'): | |
| image = np.array(image.convert('RGB')) | |
| else: | |
| image = np.array(image) | |
| # Convert to grayscale for face detection | |
| gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) | |
| # Load OpenCV's face detector | |
| face_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml' | |
| if not os.path.exists(face_cascade_path): | |
| print(f"Error: Haar cascade file not found at {face_cascade_path}") | |
| return None | |
| face_cascade = cv2.CascadeClassifier(face_cascade_path) | |
| # Detect faces | |
| faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)) | |
| if len(faces) > 0: | |
| x, y, w, h = faces[0] # Get the first face | |
| face_img = image[y:y+h, x:x+w] | |
| return face_img | |
| else: | |
| return None | |
def extract_face(image):
    """Return a 224x224 face crop from ``image``, or the whole image resized.

    PIL-style inputs are converted to an RGB array first, so the detection
    path and the whole-image fallback both work on the same 3-channel data.

    Args:
        image: A PIL-style image, a NumPy array, or None.

    Returns:
        A 224x224 NumPy array (the detected face if OpenCV finds one,
        otherwise the full image), or None when ``image`` is None or cannot
        be resized.
    """
    if image is None:
        return None
    # Normalise PIL images once, up front. The fallback used to resize the
    # raw PIL image, so palette ("P"), "LA" or "CMYK" inputs reached the
    # model as index maps or with the wrong number of channels.
    if not isinstance(image, np.ndarray) and hasattr(image, 'convert'):
        image = np.array(image.convert('RGB'))

    face_img = detect_face_with_opencv(image)
    if face_img is not None:
        return cv2.resize(face_img, (224, 224))

    # If OpenCV fails, use the whole image
    print("WARNING: Could not detect face with OpenCV")
    if isinstance(image, np.ndarray):
        return cv2.resize(image, (224, 224))
    if hasattr(image, 'resize'):
        # Non-PIL objects that only expose resize(): keep the original path.
        return np.array(image.resize((224, 224)))
    return None
def preprocess_image(image):
    """Build a normalised one-image batch for the face-shape classifier.

    Three-channel NumPy arrays are treated as OpenCV BGR and swapped to RGB;
    other arrays are used unchanged. Non-array inputs go through
    ``.convert('RGB')`` when they have it, otherwise through ``np.array``.
    The result is resized to 224x224 only if needed, grayscale or RGBA data
    is coerced to three channels, and pixel values are divided by 255.

    Returns:
        The image with a leading batch axis of length 1, or None when
        ``image`` is None or any step raises (the error is printed).
    """
    if image is None:
        return None
    try:
        if isinstance(image, np.ndarray):
            looks_bgr = len(image.shape) == 3 and image.shape[2] == 3
            rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if looks_bgr else image
        elif hasattr(image, 'convert'):
            rgb = np.array(image.convert('RGB'))
        else:
            rgb = np.array(image)

        # Resample only when the spatial size differs from the model input.
        needs_resize = rgb.shape[0] != 224 or rgb.shape[1] != 224
        sized = cv2.resize(rgb, (224, 224)) if needs_resize else rgb

        # Bring single-channel and four-channel data to three channels.
        if sized.ndim == 2:
            sized = cv2.cvtColor(sized, cv2.COLOR_GRAY2RGB)
        elif sized.shape[2] == 4:
            sized = cv2.cvtColor(sized, cv2.COLOR_RGBA2RGB)

        scaled = sized / 255.0
        return scaled[np.newaxis, ...]
    except Exception as e:
        print(f"Error in image preprocessing: {e}")
        return None
def load_face_shape_model():
    """Load the face-shape classifier from ``model_path`` into the global slot.

    The model is loaded with ``compile=False``. It is only used for
    inference, and skipping compilation avoids load failures when the saved
    optimizer or loss cannot be deserialised by the installed Keras version.

    If loading fails for any reason, an *untrained* stand-in model is
    installed instead: same (224, 224, 3) input, one softmax output per entry
    of ``face_shape_labels``, and random weights. The API keeps running, but
    its predictions are meaningless; the printed warning is the only signal.

    Returns:
        The model now stored in ``face_detection_model``.
    """
    global face_detection_model
    try:
        # Force CPU usage to avoid CUDA issues
        with tf.device('/CPU:0'):
            face_detection_model = tf.keras.models.load_model(model_path, compile=False)
        print("Face shape detection model loaded successfully!")
    except Exception as e:
        print(f"Warning: Could not load face shape model: {e}")
        # Development placeholder only: untrained, random weights. Output
        # width is tied to the label list so argmax indices stay valid.
        face_detection_model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(224, 224, 3)),
            tf.keras.layers.Conv2D(16, 3, activation='relu'),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(len(face_shape_labels), activation='softmax'),
        ], name='dummy_face_shape_model')
        print("Created dummy face shape model for testing")
    return face_detection_model
def predict_face_shape(image):
    """Classify the face shape in ``image`` with the Keras model.

    The face is cropped by ``extract_face`` (falling back to the whole
    image), preprocessed, and run through ``face_detection_model``. The model
    is loaded on first use if startup did not load it.

    Args:
        image: A PIL image or NumPy array; None is rejected.

    Returns:
        ``{"face_shape": <label>, "confidence": <percent, 1 decimal>}`` on
        success. ``{"error": <message>}`` when there is no image, face
        extraction or preprocessing fails, or the model cannot be loaded.
        If inference itself raises, a default ``"Oval"`` / ``50.0`` result
        with a ``"note"`` key, so callers always get a usable shape.
    """
    global face_detection_model
    if image is None:
        return {"error": "No image provided"}

    # Extract face from image. OpenCV can raise on unusual inputs; report
    # that through the error dict like every other failure here.
    try:
        face_image = extract_face(image)
    except Exception as e:
        print(f"Error extracting face: {e}")
        face_image = None
    if face_image is None:
        return {"error": "Could not process the face in the image"}

    # Load model if not loaded
    if face_detection_model is None:
        try:
            face_detection_model = load_face_shape_model()
        except Exception as e:
            print(f"Error loading model: {e}")
            return {"error": "Could not load the face shape detection model"}

    try:
        # extract_face()/detect_face_with_opencv() work in RGB (PIL inputs
        # are converted with .convert('RGB')), but preprocess_image() treats
        # any 3-channel ndarray as OpenCV BGR and swaps it. Without this
        # conversion the model was fed BGR pixels.
        if face_image.ndim == 3 and face_image.shape[2] == 3:
            face_image = cv2.cvtColor(face_image, cv2.COLOR_RGB2BGR)

        preprocessed_image = preprocess_image(face_image)
        if preprocessed_image is None:
            return {"error": "Could not process the image"}

        # Make prediction - Force CPU usage. verbose=0 stops Keras printing a
        # progress bar on every request.
        with tf.device('/CPU:0'):
            predictions = face_detection_model.predict(preprocessed_image, verbose=0)
        scores = predictions[0]  # single-image batch
        predicted_class = int(np.argmax(scores))
        confidence = float(scores[predicted_class]) * 100
        return {
            "face_shape": face_shape_labels[predicted_class],
            "confidence": round(confidence, 1)
        }
    except Exception as e:
        print(f"Error in face shape prediction: {e}")
        # Provide a default face shape when model fails
        return {
            "face_shape": "Oval",
            "confidence": 50.0,
            "note": "Default prediction due to processing error"
        }
| ############################################################## | |
| # RECOMMENDATION DATA | |
| ############################################################## | |
# Curated product suggestions, keyed first by face shape (the same names as
# face_shape_labels) and then by product category. Lists are ranked: the
# recommendation endpoints return only the first five entries of each.
face_shape_recommendations = dict(
    Heart=dict(
        Glasses=[
            "Cat Eye Frames", "Round Frames", "Clear Frames", "Oval Glasses", "Alford Glasses",
            "Tortoiseshell Sunglasses", "Transparent Eyeglasses Frames", "Geometric Frames",
            "Aviator Glasses", "Clubmaster Frames", "Oversized Glasses", "Square Frames",
            "Wayfarer Glasses", "Browline Glasses", "Rimless Glasses", "Classic Aviators",
            "Butterfly Frames", "Pantos Frames", "Pilot Glasses", "Rectangle Frames",
        ],
        Watches=[
            "Luxury Watch", "Minimalist Watch", "Chronograph Watch", "Pilot Watch", "Diver Watch",
            "Sveston Sports Watch", "Casio G-Shock", "Casio Edifice", "Casio Protrek",
            "Fossil Silicon Watch",
            "Swiss Military Alpine", "Hanowa Puma Watch", "Swiss Chronograph",
            "Smart BT Calling Watch",
            "Infinity Smart Watch", "Vogue Smart Watch", "Realme Watch S2", "Mibro Watch C4",
            "Redmi Watch 5", "Bold Dial Watch",
        ],
        Hats=[
            "Beanie", "Wide-Brim Hat", "Trilby", "Newsboy Cap", "Cowboy Hat",
            "Trucker Hat", "Safari Hat", "Flat Cap", "Boater Hat", "Top Hat",
            "Classic Fedora", "Chitrali Cap", "Gilgiti Cap", "Pakol", "Baseball Cap",
            "Snapback Cap", "Bucket Hat", "Beret", "Panama Hat", "Pork Pie Hat",
        ],
    ),
    Oblong=dict(
        Glasses=[
            "Aviators", "Oversized Glasses", "Round Frames", "Square Frames", "Wayfarer Glasses",
            "Tortoiseshell Sunglasses", "Transparent Eyeglasses Frames", "Geometric Frames",
            "Cat Eye Frames", "Clubmaster Frames", "Oval Glasses", "Clear Frames",
            "Butterfly Frames", "Pantos Frames", "Pilot Glasses", "Rectangle Frames",
            "Browline Glasses", "Rimless Glasses", "Classic Aviators", "Embellished Sunglasses",
        ],
        Watches=[
            "Pilot Watch", "Luxury Watch", "Minimalist Watch", "Chronograph Watch", "Diver Watch",
            "Sveston Sports Watch", "Casio G-Shock", "Casio Edifice", "Casio Protrek",
            "Fossil Silicon Watch",
            "Swiss Military Alpine", "Hanowa Puma Watch", "Swiss Chronograph",
            "Smart BT Calling Watch",
            "Infinity Smart Watch", "Vogue Smart Watch", "Realme Watch S2", "Mibro Watch C4",
            "Redmi Watch 5", "Bold Dial Watch",
        ],
        Hats=[
            "Trilby", "Newsboy Cap", "Cowboy Hat", "Safari Hat", "Flat Cap",
            "Trucker Hat", "Beanie", "Wide-Brim Hat", "Boater Hat", "Top Hat",
            "Classic Fedora", "Chitrali Cap", "Gilgiti Cap", "Pakol", "Baseball Cap",
            "Snapback Cap", "Bucket Hat", "Beret", "Panama Hat", "Pork Pie Hat",
        ],
    ),
    Oval=dict(
        Glasses=[
            "Wayfarer Glasses", "Geometric Frames", "Cat Eye Frames", "Round Frames", "Clear Frames",
            "Aviator Glasses", "Clubmaster Frames", "Square Frames", "Oversized Glasses",
            "Oval Glasses",
            "Transparent Frames", "Tortoiseshell Frames", "Browline Glasses", "Classic Aviators",
            "Butterfly Frames", "Rimless Glasses", "Rectangle Frames", "Pilot Glasses",
            "Metal Frame Glasses", "Gradient Sunglasses",
        ],
        Watches=[
            "Diver Watch", "Dress Watch", "Luxury Watch", "Minimalist Watch", "Chronograph Watch",
            "Smart BT Calling Watch", "Realme Watch S2", "Fossil Gen 6 Smartwatch", "Casio Edifice",
            "Swiss Military Alpine", "Sveston Classic", "Hanowa Chronograph", "Infinity Smart Watch",
            "Mibro T1 Smartwatch", "Vogue Smart Watch", "T500+ Smart Watch", "Casio F91W",
            "Xiaomi Watch 2", "Skeleton Watch", "Bold Dial Watch",
        ],
        Hats=[
            "Cowboy Hat", "Safari Hat", "Trilby", "Newsboy Cap", "Flat Cap",
            "Wide-Brim Hat", "Boater Hat", "Top Hat", "Classic Fedora", "Pakol",
            "Gilgiti Cap", "Baseball Cap", "Bucket Hat", "Snapback Cap", "Beret",
            "Panama Hat", "Pork Pie Hat", "Sun Hat", "Chitrali Cap", "Trucker Hat",
        ],
    ),
    Round=dict(
        Glasses=[
            "Square Frames", "Browline Glasses", "Cat Eye Frames", "Round Frames", "Clear Frames",
            "Wayfarer Glasses", "Geometric Frames", "Clubmaster Frames", "Rectangle Frames",
            "Tortoiseshell Frames", "Metal Frame Glasses", "Oversized Glasses", "Aviator Glasses",
            "Butterfly Frames", "Classic Aviators", "Transparent Frames", "Rimless Glasses",
            "Oval Glasses", "Pilot Glasses", "Gradient Sunglasses",
        ],
        Watches=[
            "Bold Dial Watch", "Square Dial Watch", "Luxury Watch", "Minimalist Watch",
            "Chronograph Watch",
            "Casio G-Shock", "Sveston Classic Watch", "Swiss Military Alpine", "Hanowa Smart Watch",
            "Infinity Smart Watch", "Fossil Smart Watch", "Realme Watch S2", "Mibro T1 Smartwatch",
            "Dress Watch", "Smart BT Calling Watch", "Casio Edifice", "Vogue Smart Watch",
            "T500+ Smart Watch", "Skeleton Watch", "Retro Watch",
        ],
        Hats=[
            "Flat Cap", "Boater Hat", "Trilby", "Newsboy Cap", "Cowboy Hat",
            "Wide-Brim Hat", "Safari Hat", "Classic Fedora", "Pakol", "Chitrali Cap",
            "Snapback Cap", "Bucket Hat", "Top Hat", "Baseball Cap", "Panama Hat",
            "Pork Pie Hat", "Sun Hat", "Beret", "Trucker Hat", "Gilgiti Cap",
        ],
    ),
    Square=dict(
        Glasses=[
            "Rimless Glasses", "Classic Aviators", "Cat Eye Frames", "Round Frames", "Clear Frames",
            "Wayfarer Glasses", "Geometric Frames", "Clubmaster Frames", "Square Frames",
            "Tortoiseshell Glasses",
            "Aviator Glasses", "Browline Glasses", "Transparent Frames", "Butterfly Frames",
            "Rectangle Frames", "Pilot Glasses", "Metal Frame Glasses", "Oversized Frames",
            "Oval Glasses", "Gradient Sunglasses",
        ],
        Watches=[
            "Skeleton Watch", "Retro Watch", "Luxury Watch", "Minimalist Watch", "Chronograph Watch",
            "Dress Watch", "Casio Edifice", "Smart BT Calling Watch", "Infinity Smart Watch",
            "Realme Watch S2", "Fossil Gen 6", "Mibro T1", "Swiss Military Alpine",
            "Hanowa Puma Watch", "Casio G-Shock", "Redmi Watch 5", "Vogue Smart Watch",
            "Bold Dial Watch", "Square Dial Watch", "Pilot Watch",
        ],
        Hats=[
            "Top Hat", "Classic Fedora", "Trilby", "Newsboy Cap", "Cowboy Hat",
            "Flat Cap", "Safari Hat", "Boater Hat", "Snapback Cap", "Bucket Hat",
            "Baseball Cap", "Panama Hat", "Pork Pie Hat", "Beret", "Sun Hat",
            "Wide-Brim Hat", "Trucker Hat", "Chitrali Cap", "Pakol", "Gilgiti Cap",
        ],
    ),
)
| ############################################################## | |
| # API ROUTES | |
| ############################################################## | |
# Route registration was missing for every handler in this module, which
# left the whole API unreachable (Flask answered 404 for every path).
@app.route('/', methods=['GET'])
def home():
    """Health check: confirm the API is running and list its main endpoints."""
    return jsonify({
        "message": "AI Fashion Recommendation API is running!",
        "version": "1.0",
        "endpoints": {
            "image_recommendations": "/predict/image",
            "text_recommendations": "/predict/text",
            "face_shape_detection": "/detect/face-shape"
        }
    })
@app.route('/predict/image', methods=['POST'])
def predict_image_recommendations():
    """Recommend products for the face shape detected in an uploaded image.

    Accepts either:
      * multipart/form-data with an ``image`` file and one or more
        ``categories`` fields, or
      * a JSON object ``{"image_base64": ..., "categories": [...]}``, where the
        image is plain base64 or a ``data:<mime>;base64,`` URL. A single
        category may also be sent as a bare string.

    Returns:
        200 with ``face_shape_info``, ``recommendations`` (up to five items
        per requested category; unknown categories map to ``[]``) and the
        echoed ``categories``. If face-shape detection fails, "Oval" is
        assumed and flagged in ``face_shape_info["note"]``.
        400 when the image or categories are missing, or the image cannot be
        decoded. 500 with the exception text on anything unexpected.
    """
    try:
        if request.is_json:
            # silent=True returns None for malformed JSON instead of raising.
            # The old `'image_base64' not in request.json` check raised (or hit
            # None) on non-JSON requests and surfaced as a 500.
            data = request.get_json(silent=True)
            if not isinstance(data, dict) or not data.get('image_base64'):
                return jsonify({"error": "No image provided"}), 400
            categories = data.get('categories', [])
            encoded = data['image_base64']
            # Strip a "data:image/...;base64," prefix. The base64 alphabet
            # has no comma, so plain base64 is unaffected.
            if isinstance(encoded, str) and ',' in encoded:
                encoded = encoded.split(',', 1)[1]
            try:
                image_stream = io.BytesIO(base64.b64decode(encoded))
            except (TypeError, ValueError):  # binascii.Error is a ValueError
                return jsonify({"error": "Invalid image data"}), 400
        else:
            # Handle file upload
            if 'image' not in request.files:
                return jsonify({"error": "No image provided"}), 400
            categories = request.form.getlist('categories')
            image_stream = request.files['image'].stream

        if isinstance(categories, str):
            categories = [categories]
        if not categories:
            return jsonify({"error": "Please select at least one product category"}), 400

        try:
            image = Image.open(image_stream)
            # Image.open is lazy; decode now so corrupt data is a client error.
            image.load()
        except OSError:  # includes PIL.UnidentifiedImageError
            return jsonify({"error": "Invalid image data"}), 400

        # Predict face shape
        face_shape_result = predict_face_shape(image)
        if "error" in face_shape_result:
            face_shape = "Oval"  # Default
            face_shape_info = {
                "face_shape": face_shape,
                "confidence": 50.0,
                "note": "Using default face shape due to detection error"
            }
        else:
            face_shape = face_shape_result["face_shape"]
            face_shape_info = face_shape_result

        # Top five per requested category; unknown categories yield [].
        shape_recs = face_shape_recommendations.get(face_shape, {})
        recommendations = {
            category: shape_recs.get(category, [])[:5] for category in categories
        }

        return jsonify({
            "face_shape_info": face_shape_info,
            "recommendations": recommendations,
            "categories": categories
        })
    except Exception as e:
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
@app.route('/predict/text', methods=['POST'])
def predict_text_recommendations():
    """Recommend products from user attributes, without an image.

    Expects a JSON object with optional ``gender``, ``skin_tone`` and
    ``age_group`` fields, plus ``categories`` (a list, or a single string).
    The attributes are echoed back but do not affect the result.
    Recommendations always come from the "Oval" face-shape table.

    Returns:
        200 with ``user_attributes``, ``recommendations`` (up to five items
        per category; unknown categories map to ``[]``), ``categories`` and a
        ``note``. 400 for a non-JSON/non-object body or no categories. 500
        with the exception text on anything unexpected.
    """
    try:
        # get_json() raised or returned None for non-JSON bodies, which
        # became a 500; treat that as a client error instead.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400
        gender = data.get('gender')
        skin_tone = data.get('skin_tone')
        age_group = data.get('age_group')
        categories = data.get('categories', [])
        if isinstance(categories, str):
            categories = [categories]
        if not categories:
            return jsonify({"error": "Please select at least one product category"}), 400

        # For text-based recommendations, use Oval as default face shape
        oval_recs = face_shape_recommendations.get("Oval", {})
        recommendations = {
            category: oval_recs.get(category, [])[:5] for category in categories
        }

        return jsonify({
            "user_attributes": {
                "gender": gender,
                "skin_tone": skin_tone,
                "age_group": age_group
            },
            "recommendations": recommendations,
            "categories": categories,
            "note": "Recommendations based on general fashion trends"
        })
    except Exception as e:
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
@app.route('/detect/face-shape', methods=['POST'])
def detect_face_shape_only():
    """Detect the face shape in an uploaded image.

    Accepts a multipart ``image`` file or a JSON object with ``image_base64``
    (plain base64 or a ``data:<mime>;base64,`` URL).

    Returns:
        200 with the ``predict_face_shape`` result: either ``face_shape`` and
        ``confidence``, or an ``error`` key, which is also sent with status
        200. 400 when no image is supplied or it cannot be decoded. 500 with
        the exception text on anything unexpected.
    """
    try:
        if request.is_json:
            # silent=True: a malformed body becomes None instead of raising.
            # The old `request.json` check turned non-JSON requests into 500s.
            data = request.get_json(silent=True)
            if not isinstance(data, dict) or not data.get('image_base64'):
                return jsonify({"error": "No image provided"}), 400
            encoded = data['image_base64']
            # Strip a "data:image/...;base64," prefix (plain base64 has no comma).
            if isinstance(encoded, str) and ',' in encoded:
                encoded = encoded.split(',', 1)[1]
            try:
                image_stream = io.BytesIO(base64.b64decode(encoded))
            except (TypeError, ValueError):  # binascii.Error is a ValueError
                return jsonify({"error": "Invalid image data"}), 400
        else:
            # Handle file upload
            if 'image' not in request.files:
                return jsonify({"error": "No image provided"}), 400
            image_stream = request.files['image'].stream

        try:
            image = Image.open(image_stream)
            # Image.open is lazy; decode now so corrupt data is a client error.
            image.load()
        except OSError:  # includes PIL.UnidentifiedImageError
            return jsonify({"error": "Invalid image data"}), 400

        # Predict face shape
        face_shape_result = predict_face_shape(image)
        return jsonify(face_shape_result)
    except Exception as e:
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
# NOTE(review): no route was registered for this handler, so it was
# unreachable. This path is a best guess (home() does not advertise it) —
# confirm it matches what clients call.
@app.route('/categories', methods=['GET'])
def get_categories():
    """Return the option lists clients use to build requests.

    Includes the product categories, the face-shape labels, and the gender,
    skin-tone and age-group choices that ``predict_text_recommendations``
    accepts and echoes back.
    """
    return jsonify({
        "categories": ["Glasses", "Watches", "Hats"],
        "face_shapes": face_shape_labels,
        "gender_options": ["Male", "Female", "Kid", "Transgender"],
        "skin_tone_options": ["Fair", "Medium", "Dark"],
        "age_group_options": ["Child (0-12)", "Teen (13-19)", "Young Adult (20-35)", "Adult (36-50)", "Senior (51+)"]
    })
| ############################################################## | |
| # MAIN EXECUTION | |
| ############################################################## | |
if __name__ == '__main__':
    # Load the model up front so the first request does not pay for it
    # (predict_face_shape would otherwise load it lazily).
    print("Loading face shape detection model...")
    load_face_shape_model()
    print("API is ready!")

    # debug=True enables Werkzeug's interactive debugger, which can execute
    # arbitrary Python. Bound to 0.0.0.0, that is exposed to the network.
    # Debug mode also starts the reloader, which re-runs this module and
    # loads the model twice. Make it opt-in via FLASK_DEBUG.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    # Hosting platforms commonly inject the port via PORT; default stays 5000.
    port = int(os.environ.get('PORT', '5000'))
    # Run the Flask app
    app.run(host='0.0.0.0', port=port, debug=debug)