File size: 2,172 Bytes
a5a6a2e
7f3db4a
a5a6a2e
 
7f3db4a
 
 
a5a6a2e
7f3db4a
 
 
 
 
 
 
 
 
 
 
 
a5a6a2e
 
 
7f3db4a
 
a5a6a2e
 
 
 
 
 
 
 
 
 
7f3db4a
 
a5a6a2e
 
7f3db4a
a5a6a2e
 
 
 
7f3db4a
 
a5a6a2e
7f3db4a
 
a5a6a2e
 
 
7f3db4a
 
 
a5a6a2e
 
 
7f3db4a
 
 
 
 
 
 
a5a6a2e
7f3db4a
 
 
 
 
 
 
 
 
a5a6a2e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import joblib
from huggingface_hub import hf_hub_download

from geometry import extract_features
from landmarks import get_landmarks

REPO_ID = "codernotme/kataria_optical"
MODEL_PATH = "face_shape_model.pkl"

# Global model cache
_model = None


def _get_feature_vector(features):
    return [
        features.get("lw_ratio", 0),
        features.get("jaw_ratio", 0),
        features.get("forehead_ratio", 0),
    ]


def load_model():
    """Return the face-shape classifier, loading it on first use.

    ``MODEL_PATH`` is tried as a local file first (relative to the current
    working directory); if it does not exist, the file is downloaded from the
    Hugging Face Hub repository ``REPO_ID``. A successfully loaded model is
    stored in the module-level ``_model`` so later calls return it directly.

    Returns:
        The deserialized model, or ``None`` when downloading or loading fails.
        Errors are printed rather than raised. A failure is not cached, so the
        next call attempts the download/load again.
    """
    global _model
    if _model is not None:
        return _model

    model_file = MODEL_PATH
    if not os.path.exists(model_file):
        try:
            print(f"Downloading {MODEL_PATH} from HF Hub...")
            model_file = hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
        except Exception as exc:
            print(f"Could not download from HF Hub: {exc}")
            return None

    # SECURITY NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code. Only load artifacts from a trusted repository/revision.
    try:
        _model = joblib.load(model_file)
        print("Loaded face shape model.")
    except Exception as exc:
        print(f"Failed to load model: {exc}")
    return _model


def classify_face_shape(image_input):
    """
    Classify the face shape in an image using the trained face-shape model.

    Args:
        image_input: File path, PIL Image, or numpy array (passed straight to
            ``get_landmarks``).

    Returns:
        dict: Mapping of class label (as ``str``) to probability. The
        probabilities are normalized to sum to ~1, rounded to 4 decimals and
        ordered from most to least likely. Sentinel results:

        * ``{"Unknown": 1.0}`` if the model is unavailable, ``image_input``
          is ``None``, or the model's ``classes_`` is missing/empty.
        * ``{"Error": 1.0}`` if landmark extraction, feature extraction or
          prediction raises (the error is printed).
    """
    model = load_model()

    if model is None or image_input is None:
        return {"Unknown": 1.0}

    try:
        landmarks = get_landmarks(image_input)
        feats = extract_features(landmarks)
        vector = _get_feature_vector(feats)

        # One sample in, so take the single row of class probabilities.
        probabilities = model.predict_proba([vector])[0]
        labels = list(getattr(model, "classes_", []))
        if not labels:
            return {"Unknown": 1.0}

        # Normalize the raw probabilities and round only once, at the end.
        # Rounding before normalizing would compound the rounding error.
        raw_scores = {
            str(label): float(score)
            for label, score in zip(labels, probabilities)
        }
        total_score = sum(raw_scores.values()) or 1.0  # guard against all-zero
        scores = {
            label: round(score / total_score, 4)
            for label, score in raw_scores.items()
        }
        return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))

    except Exception as e:
        # Deliberate best-effort boundary: callers get a sentinel, not a raise.
        print(f"Prediction error: {e}")
        return {"Error": 1.0}