# Spaces: Sleeping (Hugging Face Space status banner captured when this file was scraped)
import os

import requests
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
# --- Configuration ---

# Remote location of the pretrained face-shape classifier checkpoint.
MODEL_URL = "https://huggingface.co/fahd9999/face_shape_classification/resolve/main/model_85_nn_.pth"
# Local filename the checkpoint is cached under.
MODEL_PATH = "model_85_nn_.pth"
# Output labels, ordered to match the model's logit indices.
CLASS_NAMES = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']

# Force CPU: the Hugging Face Spaces free tier has no GPU.
DEVICE = torch.device('cpu')
def download_model_if_not_exists():
    """Download the model checkpoint from Hugging Face if not cached locally.

    The file is streamed to a temporary ``.part`` path and only renamed into
    place once the download completes, so an interrupted transfer can never
    leave a truncated file that a later run would mistake for the model.

    Raises:
        requests.RequestException: if the HTTP request fails or times out
            (logged, then re-raised).
    """
    if os.path.exists(MODEL_PATH):
        print("Model already exists locally.")
        return

    print(f"Model not found locally at {MODEL_PATH}, downloading from Hugging Face...")
    tmp_path = MODEL_PATH + ".part"
    try:
        # (connect, read) timeouts guard against a hung connection blocking
        # startup forever; the original call could wait indefinitely.
        response = requests.get(MODEL_URL, stream=True, timeout=(10, 300))
        response.raise_for_status()
        with open(tmp_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        # Atomic rename: MODEL_PATH either doesn't exist or is complete.
        os.replace(tmp_path, MODEL_PATH)
        print(f"Model downloaded and saved to {MODEL_PATH}")
    except Exception as e:
        # Remove any partial download so the next attempt starts clean.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        print(f"Failed to download model: {e}")
        raise
def load_model():
    """Load the pickled model from disk and prepare it for CPU inference.

    Returns:
        The deserialized model (presumably a ``torch.nn.Module``; the
        checkpoint is a full pickled object, not a state_dict) in eval
        mode on ``DEVICE``.

    Raises:
        Exception: re-raised after logging if download or deserialization fails.
    """
    download_model_if_not_exists()
    try:
        # The checkpoint is a whole pickled model object, so weights_only
        # must be False — torch >= 2.6 defaults it to True, which would make
        # this load fail with an UnpicklingError.
        # NOTE(security): unpickling executes arbitrary code; acceptable here
        # only because the file comes from our own model repository.
        model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
        model.eval()
        model.to(DEVICE)
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
# Process-wide model singleton, populated lazily by get_model().
model = None


def get_model():
    """Return the shared model instance, loading it on first access."""
    global model
    if model is not None:
        return model
    model = load_model()
    return model
# Built once at import time instead of on every call (the original
# reconstructed the Compose pipeline per image). Mean/std are the standard
# ImageNet normalization statistics.
_PREPROCESS = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def preprocess_image(image_file):
    """Load an image and convert it to a normalized model-input tensor.

    Args:
        image_file: path or file-like object accepted by ``PIL.Image.open``.

    Returns:
        torch.Tensor: shape (1, 3, 224, 224), normalized, on CPU.
    """
    image = Image.open(image_file).convert("RGB")
    return _PREPROCESS(image).unsqueeze(0)
def predict(image_file):
    """Classify the face shape in an image.

    Args:
        image_file: path or file-like object readable by PIL.

    Returns:
        dict: {
            "predicted_class": str,   # best class name from CLASS_NAMES
            "confidence": float,      # softmax probability of that class
            "probabilities": dict,    # class name -> probability, all classes
        }
    """
    net = get_model()
    batch = preprocess_image(image_file).to(DEVICE)

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        logits = net(batch)
        probs = F.softmax(logits, dim=1)
        top_prob, top_idx = torch.max(probs, 1)

    best_index = top_idx.item()
    probs_by_name = dict(zip(CLASS_NAMES, (p.item() for p in probs[0])))

    return {
        "predicted_class": CLASS_NAMES[best_index],
        "confidence": top_prob.item(),
        "probabilities": probs_by_name,
    }