Spaces:
Running
Running
| import io | |
| import base64 | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| from fastapi import HTTPException | |
| from services.inference import BigFiveRegressor | |
| from schemas.predict import OCEANTraits, PredictionResponse | |
| from services.face_extractor import FaceExtractor | |
# Inference device: CUDA whenever torch can see a GPU, otherwise the CPU.
# To force CPU inference (e.g. while debugging), hard-code DEVICE = "cpu".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
class ModelManager:
    """Registry of Big Five (OCEAN) regressors and their preprocessing transforms.

    Models are registered with ``load_hf_model_pipeline`` and served by
    ``predict``. Model keys are lowercased on registration so that they match
    the lowercased lookup ``predict`` performs on the requested model type.
    """

    # Order of the regressor's five outputs ("the order the model was trained
    # on"). NOTE(review): confirm against the training code if the head changes.
    _RAW_TRAIT_ORDER = (
        "Extraversion",
        "Neuroticism",
        "Agreeableness",
        "Conscientiousness",
        "Openness",
    )

    # ImageNet per-channel mean/std used by the Normalize step.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        self.models = {}  # lowercased key -> model in eval mode on DEVICE
        self.transforms_dict = {}  # lowercased key -> preprocessing transform
        self.model_configs = {}  # lowercased key -> caller-supplied model_info
        try:
            self.face_extractor = FaceExtractor()
        except Exception as e:
            # Best effort: without an extractor, predict() uses the whole image.
            print(f"Warning: Failed to initialize FaceExtractor: {e}")
            self.face_extractor = None

    def load_hf_model_pipeline(self, model_key: str, repo_id: str, model_info: dict = None):
        """Load a model from the Hugging Face Hub and register it with its transform.

        Args:
            model_key: Name the model is served under. It is stored lowercased so
                that ``predict`` (which lowercases the requested type) can find it.
            repo_id: Repository passed to ``BigFiveRegressor.from_pretrained``.
            model_info: Optional metadata; stored in ``self.model_configs`` when
                truthy.

        A loading failure is printed and swallowed on purpose, so one bad
        repository does not stop the other models from loading. The key is then
        simply absent from ``self.models``.
        """
        key = model_key.lower()
        try:
            model = BigFiveRegressor.from_pretrained(repo_id)
            model.to(DEVICE)
            model.eval()
        except Exception as e:
            print(f"⚠️ Failed to load {model_key} from {repo_id}. Error: {e}")
            return

        # SwinV2 uses 256x256, ViT/PVTv2 use 224x224. The key is lowercased, so
        # "SwinV2" / "SWINV2" are matched too.
        img_size = 256 if "swinv2" in key else 224
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
        ])

        self.models[key] = model
        self.transforms_dict[key] = transform
        if model_info:
            self.model_configs[key] = model_info
        print(f"✅ Loaded {key.upper()} from {repo_id}")

    def predict(self, model_type: str, image_base64: str) -> PredictionResponse:
        """Predict OCEAN trait scores for a base64-encoded image.

        Args:
            model_type: Registered model key (case-insensitive).
            image_base64: Base64 image bytes, optionally prefixed with a data-URL
                header such as ``"data:image/jpeg;base64,"``.

        Returns:
            A ``PredictionResponse`` carrying the model key used, the OCEAN
            scores and, when face extraction is enabled, the face crop as a
            base64 JPEG (``None`` otherwise).

        Raises:
            HTTPException: 400 for an unknown model type, an undecodable image
                or when no face is detected; 500 when the model does not return
                exactly one score per trait.
        """
        model_key = model_type.lower()
        if model_key not in self.models:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model type. Choose from: {list(self.models.keys())}",
            )

        image = self._decode_image(image_base64)
        image, cropped_base64 = self._crop_face(image)
        raw_results = self._infer(model_key, image)

        # Re-key the model-order scores into the OCEAN schema.
        standardized_ocean = OCEANTraits(
            Openness=raw_results["Openness"],
            Conscientiousness=raw_results["Conscientiousness"],
            Extraversion=raw_results["Extraversion"],
            Agreeableness=raw_results["Agreeableness"],
            Neuroticism=raw_results["Neuroticism"],
        )
        return PredictionResponse(
            model_used=model_key,
            predictions=standardized_ocean,
            cropped_face_base64=cropped_base64,
        )

    @staticmethod
    def _decode_image(image_base64: str) -> "Image.Image":
        """Decode a (possibly data-URL-prefixed) base64 string into an RGB image."""
        try:
            # Strip header if frontend accidentally includes "data:image/jpeg;base64,"
            base64_data = image_base64.split(",")[-1]
            image_data = base64.b64decode(base64_data)
            # convert() forces PIL's lazy decoder to read the pixels inside the try.
            return Image.open(io.BytesIO(image_data)).convert("RGB")
        except Exception as err:
            raise HTTPException(status_code=400, detail="Invalid Base64 image payload.") from err

    def _crop_face(self, image: "Image.Image"):
        """Crop *image* to its main face when a face extractor is available.

        Returns ``(image, cropped_base64)``. When face extraction is disabled,
        the input image is returned unchanged together with ``None``.
        """
        if not self.face_extractor:
            return image, None
        face = self.face_extractor.extract_main_face(image)
        if face is None:
            raise HTTPException(status_code=400, detail="No face detected in the image.")
        # Echo the crop back to the client so it can show what was scored.
        buffered = io.BytesIO()
        face.save(buffered, format="JPEG")
        return face, base64.b64encode(buffered.getvalue()).decode("utf-8")

    def _infer(self, model_key: str, image: "Image.Image") -> dict:
        """Run the model on *image*; return ``{trait: score}`` in model output order."""
        input_tensor = self.transforms_dict[model_key](image).unsqueeze(0).to(DEVICE)
        model = self.models[model_key]
        with torch.no_grad():
            # Mixed precision on whichever device type inference runs on.
            with torch.amp.autocast("cuda" if DEVICE == "cuda" else "cpu"):
                output = model(input_tensor)

        # Upcast from autocast's reduced precision and flatten the batch of one.
        scores = output.detach().to(torch.float32).cpu().reshape(-1).tolist()
        expected = len(self._RAW_TRAIT_ORDER)
        if len(scores) != expected:
            # A silent zip() would truncate or mis-map traits; fail loudly instead.
            raise HTTPException(
                status_code=500,
                detail=f"Model '{model_key}' returned {len(scores)} scores; expected {expected}.",
            )
        return dict(zip(self._RAW_TRAIT_ORDER, scores))
# Shared, module-level ModelManager: created once when this module is imported
# (which also attempts to build its FaceExtractor) and reused by the importers.
model_manager: ModelManager = ModelManager()