Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| import tensorflow as tf | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import os | |
| import warnings | |
| import base64 | |
| import io | |
| from pydantic import BaseModel | |
| from typing import List, Optional, Dict | |
# Suppress all Python warnings process-wide.
warnings.filterwarnings('ignore')

# Initialize FastAPI app
app = FastAPI(title="AI Fashion Recommendation API")

# Configure CORS.
# No endpoint in this module reads cookies or Authorization headers, so
# credentialed CORS is not needed. Combining allow_credentials=True with a
# wildcard origin makes Starlette echo back the caller's Origin for
# credentialed requests, effectively granting every site credentialed access.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configure TensorFlow to use CPU only.
# Set the environment variable first so it is already in place when
# TensorFlow configures its devices below.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf.config.set_visible_devices([], 'GPU')

# Face shape labels. Presumably ordered to match the classifier's output
# indices (predictions are mapped through this list by argmax) - confirm
# against the training code before reordering.
face_shape_labels = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']

# Global variables for models (populated by load_face_shape_model).
face_detection_model = None

# Path to the trained Keras model. Override it with the
# FACE_SHAPE_MODEL_PATH environment variable; the default is unchanged.
model_path = os.environ.get('FACE_SHAPE_MODEL_PATH', './Try_Face_Detection_AI_1.keras')
| # Pydantic models for request validation | |
class TextRecommendationRequest(BaseModel):
    """JSON body for the text-based (no image) recommendation endpoint.

    Field order matters: it defines the JSON schema and the key order of the
    serialized model echoed back to the client.
    """
    gender: str  # free text; presumably one of the "gender_options" values - not validated here
    skin_tone: str  # free text; presumably one of the "skin_tone_options" values - not validated here
    age_group: str  # free text; presumably one of the "age_group_options" values - not validated here
    categories: List[str]  # accessory groups to recommend, e.g. "Glasses", "Watches", "Hats"
class Base64ImageRequest(BaseModel):
    """JSON body carrying a base64-encoded image plus the wanted accessory categories."""
    image_base64: str  # base64 text of the encoded image file; passed to base64.b64decode
    # NOTE(review): required by the model even for face-shape-only requests,
    # which do not use it - consider making it optional.
    categories: List[str]
| ############################################################## | |
| # FACE DETECTION AND PROCESSING FUNCTIONS | |
| ############################################################## | |
# Haar cascade shared by all requests; created lazily by _get_face_cascade().
_face_cascade = None


def _get_face_cascade():
    """Return the frontal-face Haar cascade, loading it on first use.

    Raises:
        HTTPException: 500 if the cascade XML is missing or cannot be parsed.
    """
    global _face_cascade
    if _face_cascade is None:
        face_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        if not os.path.exists(face_cascade_path):
            raise HTTPException(status_code=500, detail="Haar cascade file not found")
        cascade = cv2.CascadeClassifier(face_cascade_path)
        # A classifier that failed to parse its XML is "empty" and would
        # silently detect nothing.
        if cascade.empty():
            raise HTTPException(status_code=500, detail="Haar cascade file could not be loaded")
        _face_cascade = cascade
    return _face_cascade


def detect_face_with_opencv(image):
    """Detect face using OpenCV's Haar Cascade and return the face crop.

    Args:
        image: A NumPy array (RGB, RGBA or single-channel grayscale), a
            PIL-like object with ``convert`` (converted to RGB), or anything
            ``np.array`` accepts. ``None`` is allowed.

    Returns:
        The crop of the largest detected face, taken from the (converted)
        array, or ``None`` when the input is ``None`` or no face is found.

    Raises:
        HTTPException: 500 if the Haar cascade cannot be loaded.
    """
    if image is None:
        return None
    if not isinstance(image, np.ndarray):
        image = np.array(image.convert('RGB')) if hasattr(image, 'convert') else np.array(image)

    # cvtColor(RGB2GRAY) accepts only 3- or 4-channel input; grayscale input
    # can be fed to the detector directly.
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 1:
        gray = image[:, :, 0]
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    faces = _get_face_cascade().detectMultiScale(
        gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
    )
    if len(faces) == 0:
        return None

    # Several boxes (including false positives) may be returned in no
    # particular order; the largest one is the most likely real face.
    x, y, w, h = max(faces, key=lambda box: box[2] * box[3])
    return image[y:y + h, x:x + w]
def extract_face(image):
    """Return a 224x224 image of the face, falling back to the whole image.

    Args:
        image: A NumPy array or a PIL-like image.

    Returns:
        A 224x224 NumPy array: the detected face crop resized, or, when no
        face is detected, the whole image resized. PIL inputs are converted
        to RGB first, matching what face detection does. Returns ``None`` if
        the input type cannot be resized.
    """
    face_img = detect_face_with_opencv(image)
    if face_img is not None:
        return cv2.resize(face_img, (224, 224))

    print("WARNING: Could not detect face with OpenCV")
    if isinstance(image, np.ndarray):
        return cv2.resize(image, (224, 224))
    if hasattr(image, 'convert'):
        # Without the RGB conversion, palette ("P"), "LA" or "CMYK" images
        # would reach the model with the wrong channel layout.
        return np.array(image.convert('RGB').resize((224, 224)))
    if hasattr(image, 'resize'):
        return np.array(image.resize((224, 224)))
    return None
def preprocess_image(image, *, bgr=False):
    """Convert a face image into a model-ready batch of one.

    Args:
        image: A NumPy array of shape (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4), or a PIL-like object exposing ``convert``. Arrays are
            treated as RGB(A), which is what ``extract_face`` produces.
        bgr: Set to True when a 3-channel array is in OpenCV BGR order (for
            example straight from ``cv2.imread``) so it is flipped to RGB.

    Returns:
        A float32 array of shape (1, 224, 224, 3) with values in [0, 1].

    Raises:
        HTTPException: 500 if the image cannot be converted.
    """
    try:
        if isinstance(image, np.ndarray):
            # Previously every 3-channel array went through BGR->RGB, which
            # turned the RGB crops from extract_face into BGR.
            rgb_image = image
        elif hasattr(image, 'convert'):
            rgb_image = np.array(image.convert('RGB'))
        else:
            rgb_image = np.array(image)

        # Treat a singleton channel axis as plain grayscale.
        if rgb_image.ndim == 3 and rgb_image.shape[2] == 1:
            rgb_image = rgb_image[:, :, 0]

        if rgb_image.shape[0] != 224 or rgb_image.shape[1] != 224:
            rgb_image = cv2.resize(rgb_image, (224, 224))

        if rgb_image.ndim == 2:
            # Grayscale -> three identical channels (same as COLOR_GRAY2RGB).
            rgb_image = np.stack([rgb_image] * 3, axis=-1)
        elif rgb_image.shape[2] == 4:
            # Drop the alpha channel (same as COLOR_RGBA2RGB).
            rgb_image = rgb_image[:, :, :3]
        elif bgr and rgb_image.shape[2] == 3:
            rgb_image = rgb_image[:, :, ::-1]

        normalized_image = rgb_image.astype(np.float32) / 255.0
        return np.expand_dims(normalized_image, axis=0)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Image preprocessing failed: {str(e)}") from e
def load_face_shape_model():
    """Load the Keras face-shape classifier into ``face_detection_model``.

    If the model file at ``model_path`` cannot be loaded, a small untrained
    network with the same input/output shape is installed instead, so the API
    keeps answering. That placeholder is named ``"untrained_placeholder"`` so
    that callers can recognise its predictions as meaningless.

    Returns:
        The loaded (or placeholder) model, which is also stored in the
        module-level ``face_detection_model``.
    """
    global face_detection_model
    try:
        with tf.device('/CPU:0'):
            face_detection_model = tf.keras.models.load_model(model_path)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Warning: Could not load model: {e}")
        face_detection_model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(224, 224, 3)),
            tf.keras.layers.Conv2D(16, 3, activation='relu'),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(5, activation='softmax')
        ], name="untrained_placeholder")
        print("Created dummy model")
    return face_detection_model
def predict_face_shape(image):
    """Classify the face shape in ``image``.

    Args:
        image: A PIL-like image or NumPy array; must not be ``None``.

    Returns:
        ``{"face_shape": <label>, "confidence": <percent, 1 decimal>}``.
        When no face image can be produced or prediction fails, it returns
        ``"Oval"`` at 50.0 with an explanatory ``"note"``. A ``"note"`` is also
        added when the prediction comes from the untrained placeholder model.

    Raises:
        HTTPException: 400 if ``image`` is ``None``; errors raised by face
            extraction (e.g. a missing Haar cascade) propagate.
    """
    global face_detection_model
    if image is None:
        raise HTTPException(status_code=400, detail="No image provided")

    if face_detection_model is None:
        # The startup hook may not have run (e.g. the module was imported
        # rather than served); load on first use instead of failing below.
        load_face_shape_model()

    face_image = extract_face(image)
    if face_image is None:
        return {"face_shape": "Oval", "confidence": 50.0, "note": "Default due to face detection error"}

    try:
        preprocessed_image = preprocess_image(face_image)
        with tf.device('/CPU:0'):
            # verbose=0: don't print a Keras progress bar on every request.
            predictions = face_detection_model.predict(preprocessed_image, verbose=0)
        scores = np.asarray(predictions)[0]  # single-image batch
        predicted_class = int(np.argmax(scores))
        confidence = float(scores[predicted_class]) * 100
        result = {
            "face_shape": face_shape_labels[predicted_class],
            "confidence": round(confidence, 1)
        }
        # Name given to the fallback network by load_face_shape_model.
        if getattr(face_detection_model, "name", None) == "untrained_placeholder":
            result["note"] = "Model file unavailable; prediction from untrained placeholder model"
        return result
    except Exception as e:
        print(f"Prediction error: {e}")
        return {"face_shape": "Oval", "confidence": 50.0, "note": "Default due to prediction error"}
##############################################################
# RECOMMENDATION DATA (unchanged from the original version of this app)
##############################################################
# Accessory suggestions keyed by face shape (the same labels as
# face_shape_labels), then by category ("Glasses", "Watches", "Hats").
# The endpoints return only the first five items of each list, so list order
# acts as the ranking.
# NOTE(review): item names are not normalised across shapes, e.g.
# "Tortoiseshell Sunglasses" / "Tortoiseshell Frames" / "Tortoiseshell Glasses",
# "Oversized Glasses" / "Oversized Frames", "Mibro T1" / "Mibro T1 Smartwatch".
face_shape_recommendations = {
    "Heart": {
        "Glasses": [
            "Cat Eye Frames", "Round Frames", "Clear Frames", "Oval Glasses", "Alford Glasses",
            "Tortoiseshell Sunglasses", "Transparent Eyeglasses Frames", "Geometric Frames",
            "Aviator Glasses", "Clubmaster Frames", "Oversized Glasses", "Square Frames",
            "Wayfarer Glasses", "Browline Glasses", "Rimless Glasses", "Classic Aviators",
            "Butterfly Frames", "Pantos Frames", "Pilot Glasses", "Rectangle Frames"
        ],
        "Watches": [
            "Luxury Watch", "Minimalist Watch", "Chronograph Watch", "Pilot Watch", "Diver Watch",
            "Sveston Sports Watch", "Casio G-Shock", "Casio Edifice", "Casio Protrek", "Fossil Silicon Watch",
            "Swiss Military Alpine", "Hanowa Puma Watch", "Swiss Chronograph", "Smart BT Calling Watch",
            "Infinity Smart Watch", "Vogue Smart Watch", "Realme Watch S2", "Mibro Watch C4",
            "Redmi Watch 5", "Bold Dial Watch"
        ],
        "Hats": [
            "Beanie", "Wide-Brim Hat", "Trilby", "Newsboy Cap", "Cowboy Hat",
            "Trucker Hat", "Safari Hat", "Flat Cap", "Boater Hat", "Top Hat",
            "Classic Fedora", "Chitrali Cap", "Gilgiti Cap", "Pakol", "Baseball Cap",
            "Snapback Cap", "Bucket Hat", "Beret", "Panama Hat", "Pork Pie Hat"
        ]
    },
    "Oblong": {
        "Glasses": [
            "Aviators", "Oversized Glasses", "Round Frames", "Square Frames", "Wayfarer Glasses",
            "Tortoiseshell Sunglasses", "Transparent Eyeglasses Frames", "Geometric Frames",
            "Cat Eye Frames", "Clubmaster Frames", "Oval Glasses", "Clear Frames",
            "Butterfly Frames", "Pantos Frames", "Pilot Glasses", "Rectangle Frames",
            "Browline Glasses", "Rimless Glasses", "Classic Aviators", "Embellished Sunglasses"
        ],
        "Watches": [
            "Pilot Watch", "Luxury Watch", "Minimalist Watch", "Chronograph Watch", "Diver Watch",
            "Sveston Sports Watch", "Casio G-Shock", "Casio Edifice", "Casio Protrek", "Fossil Silicon Watch",
            "Swiss Military Alpine", "Hanowa Puma Watch", "Swiss Chronograph", "Smart BT Calling Watch",
            "Infinity Smart Watch", "Vogue Smart Watch", "Realme Watch S2", "Mibro Watch C4",
            "Redmi Watch 5", "Bold Dial Watch"
        ],
        "Hats": [
            "Trilby", "Newsboy Cap", "Cowboy Hat", "Safari Hat", "Flat Cap",
            "Trucker Hat", "Beanie", "Wide-Brim Hat", "Boater Hat", "Top Hat",
            "Classic Fedora", "Chitrali Cap", "Gilgiti Cap", "Pakol", "Baseball Cap",
            "Snapback Cap", "Bucket Hat", "Beret", "Panama Hat", "Pork Pie Hat"
        ]
    },
    # "Oval" doubles as the general-purpose list used when no face shape is
    # known (text recommendations, fallback predictions).
    "Oval": {
        "Glasses": [
            "Wayfarer Glasses", "Geometric Frames", "Cat Eye Frames", "Round Frames", "Clear Frames",
            "Aviator Glasses", "Clubmaster Frames", "Square Frames", "Oversized Glasses", "Oval Glasses",
            "Transparent Frames", "Tortoiseshell Frames", "Browline Glasses", "Classic Aviators",
            "Butterfly Frames", "Rimless Glasses", "Rectangle Frames", "Pilot Glasses",
            "Metal Frame Glasses", "Gradient Sunglasses"
        ],
        "Watches": [
            "Diver Watch", "Dress Watch", "Luxury Watch", "Minimalist Watch", "Chronograph Watch",
            "Smart BT Calling Watch", "Realme Watch S2", "Fossil Gen 6 Smartwatch", "Casio Edifice",
            "Swiss Military Alpine", "Sveston Classic", "Hanowa Chronograph", "Infinity Smart Watch",
            "Mibro T1 Smartwatch", "Vogue Smart Watch", "T500+ Smart Watch", "Casio F91W",
            "Xiaomi Watch 2", "Skeleton Watch", "Bold Dial Watch"
        ],
        "Hats": [
            "Cowboy Hat", "Safari Hat", "Trilby", "Newsboy Cap", "Flat Cap",
            "Wide-Brim Hat", "Boater Hat", "Top Hat", "Classic Fedora", "Pakol",
            "Gilgiti Cap", "Baseball Cap", "Bucket Hat", "Snapback Cap", "Beret",
            "Panama Hat", "Pork Pie Hat", "Sun Hat", "Chitrali Cap", "Trucker Hat"
        ]
    },
    "Round": {
        "Glasses": [
            "Square Frames", "Browline Glasses", "Cat Eye Frames", "Round Frames", "Clear Frames",
            "Wayfarer Glasses", "Geometric Frames", "Clubmaster Frames", "Rectangle Frames",
            "Tortoiseshell Frames", "Metal Frame Glasses", "Oversized Glasses", "Aviator Glasses",
            "Butterfly Frames", "Classic Aviators", "Transparent Frames", "Rimless Glasses",
            "Oval Glasses", "Pilot Glasses", "Gradient Sunglasses"
        ],
        "Watches": [
            "Bold Dial Watch", "Square Dial Watch", "Luxury Watch", "Minimalist Watch", "Chronograph Watch",
            "Casio G-Shock", "Sveston Classic Watch", "Swiss Military Alpine", "Hanowa Smart Watch",
            "Infinity Smart Watch", "Fossil Smart Watch", "Realme Watch S2", "Mibro T1 Smartwatch",
            "Dress Watch", "Smart BT Calling Watch", "Casio Edifice", "Vogue Smart Watch",
            "T500+ Smart Watch", "Skeleton Watch", "Retro Watch"
        ],
        "Hats": [
            "Flat Cap", "Boater Hat", "Trilby", "Newsboy Cap", "Cowboy Hat",
            "Wide-Brim Hat", "Safari Hat", "Classic Fedora", "Pakol", "Chitrali Cap",
            "Snapback Cap", "Bucket Hat", "Top Hat", "Baseball Cap", "Panama Hat",
            "Pork Pie Hat", "Sun Hat", "Beret", "Trucker Hat", "Gilgiti Cap"
        ]
    },
    "Square": {
        "Glasses": [
            "Rimless Glasses", "Classic Aviators", "Cat Eye Frames", "Round Frames", "Clear Frames",
            "Wayfarer Glasses", "Geometric Frames", "Clubmaster Frames", "Square Frames", "Tortoiseshell Glasses",
            "Aviator Glasses", "Browline Glasses", "Transparent Frames", "Butterfly Frames",
            "Rectangle Frames", "Pilot Glasses", "Metal Frame Glasses", "Oversized Frames",
            "Oval Glasses", "Gradient Sunglasses"
        ],
        "Watches": [
            "Skeleton Watch", "Retro Watch", "Luxury Watch", "Minimalist Watch", "Chronograph Watch",
            "Dress Watch", "Casio Edifice", "Smart BT Calling Watch", "Infinity Smart Watch",
            "Realme Watch S2", "Fossil Gen 6", "Mibro T1", "Swiss Military Alpine",
            "Hanowa Puma Watch", "Casio G-Shock", "Redmi Watch 5", "Vogue Smart Watch",
            "Bold Dial Watch", "Square Dial Watch", "Pilot Watch"
        ],
        "Hats": [
            "Top Hat", "Classic Fedora", "Trilby", "Newsboy Cap", "Cowboy Hat",
            "Flat Cap", "Safari Hat", "Boater Hat", "Snapback Cap", "Bucket Hat",
            "Baseball Cap", "Panama Hat", "Pork Pie Hat", "Beret", "Sun Hat",
            "Wide-Brim Hat", "Trucker Hat", "Chitrali Cap", "Pakol", "Gilgiti Cap"
        ]
    }
}
| ############################################################## | |
| # API ENDPOINTS | |
| ############################################################## | |
@app.on_event("startup")
async def load_model():
    """Load the face-shape model once when the server starts.

    Registered as a FastAPI startup handler; without that registration
    nothing in this module calls this coroutine.
    """
    print("Loading face shape detection model...")
    load_face_shape_model()
    print("API ready!")
@app.get("/")
async def home():
    """Landing endpoint: confirm the API is up and list its main routes."""
    return {
        "message": "AI Fashion Recommendation API is running!",
        "version": "1.0",
        "endpoints": {
            "image_recommendations": "/predict/image",
            "text_recommendations": "/predict/text",
            "face_shape_detection": "/detect/face-shape"
        }
    }
@app.post("/predict/image")
async def image_recommendations(
    request: Request,
    image: Optional[UploadFile] = File(None),
    categories: List[str] = []  # FastAPI parameter declaration; only rebound below, never mutated.
):
    """Predict the face shape in an image and recommend accessories for it.

    The image may be supplied either as:

    * multipart/form-data with an ``image`` file field (plus ``categories``
      form fields), or
    * a JSON body ``{"image_base64": "...", "categories": [...]}``; a
      ``data:<mime>;base64,`` prefix on the base64 string is accepted.

    Returns:
        dict with ``face_shape_info`` (the ``predict_face_shape`` result),
        ``recommendations`` (first five suggestions per requested category;
        unknown categories map to an empty list) and the echoed
        ``categories``.

    Raises:
        HTTPException: 400 when no image is supplied, the JSON/base64/image
            data cannot be decoded, or no category is selected; 500 for any
            other failure.
    """
    try:
        content_type = request.headers.get("content-type", "").lower()
        # FastAPI has already parsed form bodies to fill ``image`` and
        # ``categories``; only non-form bodies are re-read as raw JSON.
        # (Reading a multipart upload as JSON made the upload path fail.)
        is_form_body = content_type.startswith(
            ("multipart/form-data", "application/x-www-form-urlencoded")
        )

        if image is not None:
            raw_bytes = await image.read()
        elif not is_form_body and await request.body():
            try:
                data = await request.json()
            except ValueError as exc:
                raise HTTPException(status_code=400, detail="Request body is not valid JSON") from exc
            if not isinstance(data, dict) or 'image_base64' not in data:
                raise HTTPException(status_code=400, detail="No image provided")
            encoded = data['image_base64']
            if not isinstance(encoded, str):
                raise HTTPException(status_code=400, detail="image_base64 must be a string")
            if encoded.startswith("data:"):
                # Strip a data-URL header such as "data:image/png;base64,".
                encoded = encoded.split(",", 1)[-1]
            try:
                raw_bytes = base64.b64decode(encoded)
            except ValueError as exc:  # binascii.Error subclasses ValueError
                raise HTTPException(status_code=400, detail="image_base64 is not valid base64") from exc
            categories = data.get('categories', [])
        else:
            raise HTTPException(status_code=400, detail="No image provided")

        try:
            pil_image = Image.open(io.BytesIO(raw_bytes))
            # Image.open is lazy; decode now so corrupt data is a client error.
            pil_image.load()
        except (ValueError, OSError) as exc:
            raise HTTPException(status_code=400, detail="Could not read the provided image") from exc

        if not isinstance(categories, list):
            raise HTTPException(status_code=400, detail="categories must be a list of strings")
        if not categories:
            raise HTTPException(status_code=400, detail="Select at least one category")

        face_shape_result = predict_face_shape(pil_image)
        face_shape = face_shape_result.get("face_shape", "Oval")
        shape_recommendations = face_shape_recommendations.get(face_shape, {})
        recommendations = {
            category: shape_recommendations.get(category, [])[:5]
            for category in categories
        }
        return {
            "face_shape_info": face_shape_result,
            "recommendations": recommendations,
            "categories": categories
        }
    except HTTPException:
        # Keep intended status codes (e.g. 400) instead of turning them into 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/predict/text")
async def text_recommendations(request: TextRecommendationRequest):
    """Return general accessory suggestions without analysing an image.

    No face shape is known, so the "Oval" lists are used for every requested
    category (first five items each; unknown categories map to an empty
    list). ``gender``, ``skin_tone`` and ``age_group`` are echoed back under
    ``user_attributes`` but do not influence the result.

    Raises:
        HTTPException: 500 on unexpected failure.
    """
    try:
        general_recommendations = face_shape_recommendations.get("Oval", {})
        recommendations = {
            category: general_recommendations.get(category, [])[:5]
            for category in request.categories
        }
        # pydantic v2 renamed .dict() to .model_dump(); support both.
        user_attributes = request.model_dump() if hasattr(request, "model_dump") else request.dict()
        return {
            "user_attributes": user_attributes,
            "recommendations": recommendations,
            "note": "General fashion trends"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/detect/face-shape")
async def detect_face_shape(
    image: Optional[UploadFile] = File(None),
    request: Optional[Base64ImageRequest] = None,
    http_request: Request = None,
):
    """Predict the face shape for a single image.

    Accepts either a multipart ``image`` upload or a JSON body containing
    ``image_base64`` (a ``data:<mime>;base64,`` prefix is accepted). This
    route also declares a ``File`` parameter, so FastAPI reads the body as
    form data and never binds a JSON body to ``request``. The raw request
    (``http_request``, injected by FastAPI) is therefore parsed directly.
    ``request`` is still honoured when the coroutine is called with a model
    instance.

    Returns:
        The ``predict_face_shape`` result dict.

    Raises:
        HTTPException: 400 for a missing or undecodable image or malformed
            JSON; 500 for any other failure.
    """
    try:
        encoded = None
        if request is not None and request.image_base64:
            encoded = request.image_base64
        elif image is None and http_request is not None:
            content_type = http_request.headers.get("content-type", "").lower()
            # Form bodies were already parsed by FastAPI; only read raw JSON.
            is_form_body = content_type.startswith(
                ("multipart/form-data", "application/x-www-form-urlencoded")
            )
            if not is_form_body and await http_request.body():
                try:
                    payload = await http_request.json()
                except ValueError as exc:
                    raise HTTPException(status_code=400, detail="Request body is not valid JSON") from exc
                if isinstance(payload, dict):
                    encoded = payload.get("image_base64")

        if encoded:
            if not isinstance(encoded, str):
                raise HTTPException(status_code=400, detail="image_base64 must be a string")
            if encoded.startswith("data:"):
                # Strip a data-URL header such as "data:image/png;base64,".
                encoded = encoded.split(",", 1)[-1]
            try:
                raw_bytes = base64.b64decode(encoded)
            except ValueError as exc:  # binascii.Error subclasses ValueError
                raise HTTPException(status_code=400, detail="image_base64 is not valid base64") from exc
        elif image is not None:
            raw_bytes = await image.read()
        else:
            raise HTTPException(status_code=400, detail="No image provided")

        try:
            pil_image = Image.open(io.BytesIO(raw_bytes))
            # Image.open is lazy; decode now so corrupt data is a client error.
            pil_image.load()
        except (ValueError, OSError) as exc:
            raise HTTPException(status_code=400, detail="Could not read the provided image") from exc

        return predict_face_shape(pil_image)
    except HTTPException:
        # Keep intended status codes (e.g. 400) instead of turning them into 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): route path assumed to be "/categories"; confirm against the client.
@app.get("/categories")
async def get_categories():
    """Return the option lists a client UI can present.

    ``categories`` matches the per-shape keys of ``face_shape_recommendations``
    and ``face_shapes`` is the classifier's label list; the remaining lists
    are fixed UI options.
    """
    return {
        "categories": ["Glasses", "Watches", "Hats"],
        "face_shapes": face_shape_labels,
        "gender_options": ["Male", "Female", "Kid", "Transgender"],
        "skin_tone_options": ["Fair", "Medium", "Dark"],
        "age_group_options": ["Child (0-12)", "Teen (13-19)", "Young Adult (20-35)", "Adult (36-50)", "Senior (51+)"]
    }
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces. The port defaults to 5000 and can be
    # overridden with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "5000")))