import os
import requests
import torch
import torchvision.transforms as T
from PIL import Image
import torch.nn.functional as F
# Configuration
# Source checkpoint hosted in the project's own Hugging Face model repo.
MODEL_URL = "https://huggingface.co/fahd9999/face_shape_classification/resolve/main/model_85_nn_.pth"
# Local cache path for the downloaded checkpoint.
MODEL_PATH = "model_85_nn_.pth"
# Output labels, index-aligned with the model's logits — assumed to match the
# training label order; TODO confirm against the training code.
CLASS_NAMES = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']
# Device configuration (Force CPU for Hugging Face Spaces free tier compatibility)
DEVICE = torch.device('cpu')
def download_model_if_not_exists():
    """Download the model from the Hugging Face repo if it is not cached locally.

    Streams the download to disk in 8 KiB chunks so the whole checkpoint is
    never held in memory. On any failure the partially-written file is removed
    so a later call does not mistake it for a valid model, and the exception
    is re-raised for the caller.
    """
    if os.path.exists(MODEL_PATH):
        print("Model already exists locally.")
        return

    print(f"Model not found locally at {MODEL_PATH}, downloading from Hugging Face...")
    try:
        # timeout=(connect, read) prevents an unresponsive server from
        # hanging the app forever; the `with` ensures the connection is
        # released even if writing to disk fails.
        with requests.get(MODEL_URL, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            with open(MODEL_PATH, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        print(f"Model downloaded and saved to {MODEL_PATH}")
    except Exception as e:
        print(f"Failed to download model: {e}")
        # Clean up a partial download so the next attempt starts fresh.
        if os.path.exists(MODEL_PATH):
            os.remove(MODEL_PATH)
        raise
def load_model():
    """Download (if necessary) and deserialize the model onto DEVICE (CPU).

    Returns:
        torch.nn.Module: the loaded model, moved to DEVICE and set to eval mode.

    Raises:
        Exception: re-raised after logging if download or deserialization fails.
    """
    download_model_if_not_exists()
    try:
        # The checkpoint is a pickled full model object (not a state_dict), so
        # weights_only=False is required on torch >= 2.6, where the default
        # flipped to True and full-object loads started raising. Acceptable
        # here because the file comes from the project's own trusted repo.
        model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
        model.eval()
        model.to(DEVICE)
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
# Process-wide model singleton; populated lazily by the first get_model() call.
model = None

def get_model():
    """Return the shared model instance, loading it on first use."""
    global model
    if model is not None:
        return model
    model = load_model()
    return model
def preprocess_image(image_file):
    """Convert an image into a normalized (1, 3, 224, 224) float tensor.

    Accepts anything ``PIL.Image.open`` accepts (a path or a file-like
    object). The image is resized, converted to a tensor, and normalized
    with the standard ImageNet channel statistics.
    """
    pipeline = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        # ImageNet mean/std — the usual statistics for pretrained backbones.
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    source = Image.open(image_file)
    rgb = source.convert("RGB")
    # unsqueeze adds the leading batch dimension expected by the model.
    return pipeline(rgb).unsqueeze(0)
def predict(image_file):
    """Classify the face shape in an image.

    Args:
        image_file: a path or file-like object readable by PIL.

    Returns:
        dict: {
            "predicted_class": str,   # winning CLASS_NAMES entry
            "confidence": float,      # probability of the winning class
            "probabilities": dict     # class name -> probability, all classes
        }
    """
    net = get_model()
    batch = preprocess_image(image_file).to(DEVICE)

    # Pure inference — disable autograd bookkeeping.
    with torch.no_grad():
        logits = net(batch)
        probs = F.softmax(logits, dim=1)

    best_prob, best_idx = torch.max(probs, 1)
    label = CLASS_NAMES[best_idx.item()]

    # Full per-class distribution for the single image in the batch.
    distribution = dict(zip(CLASS_NAMES, (p.item() for p in probs[0])))

    return {
        "predicted_class": label,
        "confidence": best_prob.item(),
        "probabilities": distribution,
    }
|