File size: 4,411 Bytes
af35098
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import io
import base64
import torch
from PIL import Image
from torchvision import transforms
from fastapi import HTTPException
from services.inference import BigFiveRegressor
from schemas.predict import OCEANTraits, PredictionResponse
from services.face_extractor import FaceExtractor

# Inference device: use the GPU when PyTorch reports CUDA as available,
# otherwise fall back to the CPU. To force CPU inference (e.g. while
# debugging), assign DEVICE = "cpu" here instead.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class ModelManager:
    """Registry of Big Five regression models and the request-time inference path.

    Models are registered under a lowercase key by ``load_hf_model_pipeline``
    and selected case-insensitively by name in ``predict``.
    """

    # Order of the scores in the regressor's output vector, i.e. the order the
    # model was trained on. This is NOT the OCEAN order used in the response.
    RAW_TRAIT_ORDER = ('Extraversion', 'Neuroticism', 'Agreeableness', 'Conscientiousness', 'Openness')

    def __init__(self):
        self.models = {}           # model key -> loaded model (moved to DEVICE, in eval mode)
        self.transforms_dict = {}  # model key -> preprocessing transform for that model
        self.model_configs = {}    # model key -> metadata passed to load_hf_model_pipeline
        try:
            self.face_extractor = FaceExtractor()
        except Exception as e:
            # Best-effort: without an extractor, predict() runs on the whole image.
            print(f"Warning: Failed to initialize FaceExtractor: {e}")
            self.face_extractor = None

    def load_hf_model_pipeline(self, model_key: str, repo_id: str, model_info: "dict | None" = None):
        """Load a model from Hugging Face and register it with its preprocessing transform.

        Args:
            model_key: Name to register the model under. It is lowercased so it
                matches the case-insensitive lookup performed by ``predict``.
            repo_id: Repository passed to ``BigFiveRegressor.from_pretrained``.
            model_info: Optional metadata; stored in ``model_configs`` when non-empty.

        Load failures are printed and swallowed so that one bad model does not
        prevent the others from loading; the key is simply left unregistered.
        """
        key = model_key.lower()
        try:
            model = BigFiveRegressor.from_pretrained(repo_id)
            model.to(DEVICE)
            model.eval()
        except Exception as e:
            print(f"⚠️ Failed to load {model_key} from {repo_id}. Error: {e}")
            return

        # SwinV2 uses 256x256, ViT/PVTv2 use 224x224
        img_size = 256 if 'swinv2' in key else 224
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            # ImageNet channel mean/std
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self.models[key] = model
        self.transforms_dict[key] = transform
        if model_info:
            self.model_configs[key] = model_info
        print(f"✅ Loaded {key.upper()} from {repo_id}")

    def predict(self, model_type: str, image_base64: str) -> PredictionResponse:
        """Run Big Five inference on a base64-encoded image.

        Args:
            model_type: Registered model name (case-insensitive).
            image_base64: Image bytes as base64, optionally with a data-URL prefix.

        Returns:
            PredictionResponse with the scores in OCEAN form and, when a face
            extractor is configured, the cropped face as base64 JPEG.

        Raises:
            HTTPException(400): unknown model type, undecodable image, or no face found.
            ValueError: the model produced a different number of scores than
                ``RAW_TRAIT_ORDER`` expects.
        """
        model_key = model_type.lower()
        if model_key not in self.models:
            raise HTTPException(status_code=400, detail=f"Invalid model type. Choose from: {list(self.models)}")

        image = self._decode_image(image_base64)
        image, cropped_base64 = self._crop_face(image)

        # Transform and infer
        input_tensor = self.transforms_dict[model_key](image).unsqueeze(0).to(DEVICE)
        model = self.models[model_key]
        # torch.device(...).type maps e.g. "cuda:0" -> "cuda", as autocast expects.
        with torch.no_grad(), torch.amp.autocast(torch.device(DEVICE).type):
            output = model(input_tensor)
        # Autocast may produce float16/bfloat16, which NumPy cannot hold; cast first.
        scores = output.reshape(-1).to(torch.float32).cpu().numpy()

        raw_results = self._scores_by_trait(scores)

        # Re-order into the standardized OCEAN schema
        standardized_ocean = OCEANTraits(
            Openness=raw_results['Openness'],
            Conscientiousness=raw_results['Conscientiousness'],
            Extraversion=raw_results['Extraversion'],
            Agreeableness=raw_results['Agreeableness'],
            Neuroticism=raw_results['Neuroticism'],
        )

        return PredictionResponse(
            model_used=model_key,
            predictions=standardized_ocean,
            cropped_face_base64=cropped_base64,
        )

    @staticmethod
    def _decode_image(image_base64: str):
        """Decode a base64 string (optionally data-URL prefixed) into an RGB PIL image."""
        try:
            # Strip header if frontend accidentally includes "data:image/jpeg;base64,"
            payload = image_base64.split(",")[-1]
            raw = base64.b64decode(payload)
            return Image.open(io.BytesIO(raw)).convert("RGB")
        except Exception as e:
            raise HTTPException(status_code=400, detail="Invalid Base64 image payload.") from e

    def _crop_face(self, image):
        """Crop to the main face if an extractor is configured.

        Returns ``(image, cropped_base64)``. ``cropped_base64`` is the crop as a
        base64 JPEG, or ``None`` when no extractor is available (in which case
        the input image is returned unchanged).
        """
        if not self.face_extractor:
            return image, None

        face = self.face_extractor.extract_main_face(image)
        if face is None:
            raise HTTPException(status_code=400, detail="No face detected in the image.")

        # Re-encode the crop so the client can display what was analysed.
        buffered = io.BytesIO()
        face.save(buffered, format="JPEG")
        return face, base64.b64encode(buffered.getvalue()).decode("utf-8")

    @classmethod
    def _scores_by_trait(cls, scores) -> dict:
        """Label a flat score sequence with trait names in ``RAW_TRAIT_ORDER``.

        Raises ValueError on a length mismatch instead of silently truncating.
        """
        values = [float(s) for s in scores]
        if len(values) != len(cls.RAW_TRAIT_ORDER):
            raise ValueError(
                f"Expected {len(cls.RAW_TRAIT_ORDER)} trait scores from the model, got {len(values)}."
            )
        return dict(zip(cls.RAW_TRAIT_ORDER, values))

# Single shared manager for the whole process: import this object instead of
# building a new ModelManager, so every caller sees the same registered models.
model_manager: ModelManager = ModelManager()