Spaces:
Running
Running
"""
Voting Predictor — Ensemble inference engine for image classification.

Combines three models via weighted soft voting:
- M1: DINOv3 ViT-Large (weight 4/7) — best single model, self-supervised features
- M2: XGBoost on ResNet50 features (weight 1/7) — uncorrelated with neural nets
- M3: EfficientNet-B0 (weight 2/7) — lightweight CNN for diversity

XGBoost probabilities are "sharpened" (p^3 / sum(p^3)) to prevent
its flat distributions from diluting confident predictions from DINOv3.
"""
| import torch | |
| import torch.nn.functional as F | |
| import xgboost as xgb | |
| import timm, joblib, numpy as np | |
| from pathlib import Path | |
| from torchvision import models | |
| import torch.nn as nn | |
| from src.features.build_features import preprocess_image | |
class VotingPredictor:
    """Load three image models and return fused top-k predictions.

    The ensemble is a weighted soft vote over:

    - M1: a ViT-Large checkpoint (file name says "DINOv3"),
    - M2: an optional XGBoost classifier fed by a headless ResNet50,
    - M3: an EfficientNet-B0 with a replaced classification head.

    Models are loaded lazily on the first ``predict`` call (or explicitly via
    ``load_models``). The tunable constants below are class attributes so a
    subclass can override them without touching the inference code.
    """

    # Number of target classes every head is built for.
    NUM_CLASSES = 27
    # Soft-vote weights (M1 : M2 : M3 = 4 : 1 : 2).
    WEIGHT_DINO = 4.0
    WEIGHT_XGBOOST = 1.0
    WEIGHT_EFFICIENTNET = 2.0
    # Exponent applied to XGBoost probabilities before renormalising.
    SHARPEN_POWER = 3.0
    # Default number of predictions returned by ``predict``.
    TOP_K = 5

    def __init__(self, models_dir):
        """Remember where the model files live; nothing is loaded yet.

        Args:
            models_dir: Directory containing the checkpoint / encoder files.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mdir = Path(models_dir)
        self.loaded = False
        self.has_xgboost = False
        # Cached, sorted category codes for the no-XGBoost fallback path.
        self._code_list = None

    def load_models(self):
        """Load all models (+ label encoder when present) into memory, once.

        Raises:
            Whatever the underlying loaders raise when a required file is
            missing or malformed; ``self.loaded`` stays ``False`` in that case
            so a later call retries.
        """
        if self.loaded:
            return

        # M1 — ViT-Large fine-tuned for NUM_CLASSES classes. The project names
        # it "DINOv3", although the timm architecture string below is a
        # DINOv2-with-registers variant.
        # NOTE: torch.load unpickles by default on older PyTorch releases;
        # only point models_dir at trusted checkpoints.
        self.m1 = timm.create_model(
            "vit_large_patch14_reg4_dinov2.lvd142m",
            pretrained=False,
            num_classes=self.NUM_CLASSES,
        )
        self.m1.load_state_dict(
            torch.load(
                self.mdir / "M1_IMAGE_DeepLearning_DINOv3.pth",
                map_location=self.device,
            )
        )

        # M2 — XGBoost classifier on ResNet50 features (optional).
        self.has_xgboost = False
        xgb_path = self.mdir / "M2_IMAGE_Classic_XGBoost.json"
        if xgb_path.exists():
            self.m2 = xgb.XGBClassifier()
            self.m2.load_model(str(xgb_path))
            # NOTE: joblib.load unpickles the file — trusted input only.
            self.le = joblib.load(self.mdir / "M2_IMAGE_XGBoost_Encoder.pkl")
            # Headless ResNet50 (final fc layer dropped) that feeds M2.
            # NOTE(review): built with weights=None and no state dict is
            # loaded here, so unless weights are restored elsewhere these
            # features come from a randomly initialised network — confirm
            # against the training pipeline.
            res = models.resnet50(weights=None)
            self.ext = nn.Sequential(*list(res.children())[:-1])
            self.ext.to(self.device).eval()
            # Only flip the flag once every M2 component is ready.
            self.has_xgboost = True

        # M3 — EfficientNet-B0 with a custom NUM_CLASSES-way head.
        self.m3 = models.efficientnet_b0(weights=None)
        head_in = self.m3.classifier[1].in_features
        self.m3.classifier[1] = nn.Linear(head_in, self.NUM_CLASSES)
        self.m3.load_state_dict(
            torch.load(
                self.mdir / "M3_IMAGE_Classic_EfficientNetB0.pth",
                map_location=self.device,
            )
        )

        # Move the torch classifiers to the device and freeze them for inference.
        for model in (self.m1, self.m3):
            model.to(self.device).eval()
        self.loaded = True

    def predict(self, img_p, top_k=None):
        """Run voting inference on a single image.

        Args:
            img_p: Path to the image file.
            top_k: Number of predictions to return; defaults to ``TOP_K`` (5).

        Returns:
            List of dicts ``[{"label": str, "confidence": float}, ...]``
            ordered from most to least confident.
        """
        if not self.loaded:
            self.load_models()
        k = self.TOP_K if top_k is None else top_k

        with torch.no_grad():
            # M1 — uses the "dino" preprocessing profile.
            p1 = F.softmax(
                self.m1(preprocess_image(img_p, "dino").to(self.device)), dim=1
            ).cpu().numpy()[0]

            # One "standard" tensor is shared by M2 and M3.
            std_input = preprocess_image(img_p, "standard").to(self.device)

            # M3 — EfficientNet-B0.
            p3 = F.softmax(self.m3(std_input), dim=1).cpu().numpy()[0]

            p2 = None
            if self.has_xgboost:
                # M2 — XGBoost on ResNet50 features, flattened to one row.
                feats = self.ext(std_input).squeeze().cpu().numpy().reshape(1, -1)
                raw_p2 = self.m2.predict_proba(feats)[0]
                p2 = self._sharpen(raw_p2, self.SHARPEN_POWER)

        fused = self._fuse(p1, p3, p2)
        return self._top_predictions(fused, k)

    @staticmethod
    def _sharpen(probs, power=3.0):
        """Raise *probs* to *power* and renormalise so they sum to 1.

        This pushes a flat distribution towards its dominant class. If the
        powered values sum to zero or a non-finite number, the input is
        returned unchanged instead of dividing by it.
        """
        probs = np.asarray(probs, dtype=np.float64)
        powered = np.power(probs, power)
        total = powered.sum()
        if not np.isfinite(total) or total <= 0.0:
            return probs
        return powered / total

    def _fuse(self, p_dino, p_effnet, p_xgb=None):
        """Return the weighted soft vote of the per-model probability vectors.

        When *p_xgb* is ``None`` only M1 and M3 vote and the weights are
        renormalised accordingly (4:2).

        Raises:
            ValueError: If the probability vectors differ in shape.
        """
        p_dino = np.asarray(p_dino)
        p_effnet = np.asarray(p_effnet)
        if p_dino.shape != p_effnet.shape:
            raise ValueError(
                f"probability shapes differ: {p_dino.shape} vs {p_effnet.shape}"
            )
        if p_xgb is None:
            return (self.WEIGHT_DINO * p_dino + self.WEIGHT_EFFICIENTNET * p_effnet) / (
                self.WEIGHT_DINO + self.WEIGHT_EFFICIENTNET
            )

        p_xgb = np.asarray(p_xgb)
        if p_xgb.shape != p_dino.shape:
            raise ValueError(
                f"probability shapes differ: {p_dino.shape} vs {p_xgb.shape}"
            )
        return (
            self.WEIGHT_DINO * p_dino
            + self.WEIGHT_XGBOOST * p_xgb
            + self.WEIGHT_EFFICIENTNET * p_effnet
        ) / (self.WEIGHT_DINO + self.WEIGHT_XGBOOST + self.WEIGHT_EFFICIENTNET)

    def _top_predictions(self, fused, k):
        """Format the *k* highest-scoring classes of *fused* as result dicts."""
        fused = np.asarray(fused)
        k = max(0, min(int(k), fused.shape[0]))
        # Descending order; slicing from the front also handles k == 0.
        ids = np.argsort(fused)[::-1][:k]
        labels = self._labels_for(ids)
        return [
            {"label": label, "confidence": float(fused[idx])}
            for label, idx in zip(labels, ids)
        ]

    def _labels_for(self, ids):
        """Map class indices to label strings.

        Uses the label encoder when XGBoost is loaded, otherwise the sorted
        keys of ``category_mapping.json``, otherwise the index itself.
        """
        if len(ids) == 0:
            return []
        if self.has_xgboost:
            # One vectorised call instead of one call per index.
            return [str(code) for code in self.le.inverse_transform(ids)]

        codes = self._category_codes()
        if codes is None:
            return [str(i) for i in ids]
        return [codes[i] if i < len(codes) else str(i) for i in ids]

    def _category_codes(self):
        """Return category codes sorted numerically, or ``None`` if unavailable.

        The mapping file is parsed once and cached; a missing file is not
        cached, so it is picked up if it appears later.
        """
        cached = getattr(self, "_code_list", None)
        if cached is not None:
            return cached

        import json  # kept function-local, as in the original code path

        mapping_path = self.mdir / "category_mapping.json"
        if not mapping_path.exists():
            return None
        with open(mapping_path, "r", encoding="utf-8") as mf:
            cat_map = json.load(mf)
        self._code_list = sorted(cat_map.keys(), key=int)
        return self._code_list