Spaces:
Build error
Build error
| import torch | |
| from torchvision import models | |
| from safetensors.torch import load_file | |
| from huggingface_hub import hf_hub_download | |
| import joblib | |
# ---------------- GLOBAL MODELS ---------------- #
# Module-level model handles. They stay None until load_all_models() assigns
# the CNN, the KMeans risk model and its scaler.
cnn_model = kmeans_model = scaler = None
class DenseNet121_CheXpert(torch.nn.Module):
    """torchvision DenseNet-121 whose classifier head emits ``num_labels`` outputs.

    The backbone is built with ``weights=None``, so no torchvision pretrained
    weights are fetched; trained weights must be loaded via ``load_state_dict``.
    The wrapped network lives under the ``densenet`` attribute, so parameter
    names in a state dict are prefixed with ``densenet.``.
    """

    def __init__(self, num_labels=14):
        super().__init__()
        backbone = models.densenet121(weights=None)
        # Swap the stock classifier for one sized to our label count,
        # keeping the feature width the backbone already produces.
        head_width = backbone.classifier.in_features
        backbone.classifier = torch.nn.Linear(head_width, num_labels)
        self.densenet = backbone

    def forward(self, x):
        """Run *x* through the backbone and return the linear head's raw outputs.

        No sigmoid/softmax is applied here. *x* is presumably a batch of
        images shaped (batch, 3, H, W) — TODO confirm against the caller's
        preprocessing.
        """
        logits = self.densenet(x)
        return logits
| # ---------------- LOAD FUNCTION ---------------- # | |
def load_all_models(
    *,
    repo_id="itsomk/chexpert-densenet121",
    filename="pytorch_model.safetensors",
    models_dir="models",
):
    """Load the DenseNet classifier plus the KMeans risk model and its scaler.

    Populates the module-level ``cnn_model``, ``kmeans_model`` and ``scaler``.
    Everything is loaded into locals first and the three globals are assigned
    together at the end. If any step fails, the globals keep their previous
    values instead of being left half-initialised.

    Args:
        repo_id: Hugging Face Hub repository holding the CNN weights.
        filename: Safetensors file inside ``repo_id``.
        models_dir: Directory containing ``risk_model.pkl`` and
            ``risk_scaler.pkl``. A relative path is resolved against the
            current working directory (same as the original hard-coded path).

    Raises:
        RuntimeError: if no tensor name in the checkpoint matches the model.
            Otherwise the CNN would silently run with random weights.
        Errors from ``hf_hub_download``, ``load_file`` and ``joblib.load``
        (network / missing file) propagate unchanged.
    """
    import os

    global cnn_model, kmeans_model, scaler

    print("Downloading DenseNet...")
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)

    print("Loading CNN...")
    state = load_file(local_path)
    model = DenseNet121_CheXpert()

    # strict=False tolerates a few extra/missing entries. But if *nothing*
    # matches (e.g. the checkpoint lacks the "densenet." prefix), the load is
    # a silent no-op, so refuse to continue.
    if not set(model.state_dict()).intersection(state):
        raise RuntimeError(
            f"No tensor in {repo_id}/{filename} matches DenseNet121_CheXpert's "
            f"parameter names (checkpoint keys look like {sorted(state)[:3]}); "
            "refusing to run with randomly initialised weights."
        )
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        # Partial mismatch: keep going (original best-effort behaviour) but
        # make it visible instead of discarding load_state_dict's report.
        print(
            f"WARNING: checkpoint mismatch - {len(result.missing_keys)} missing, "
            f"{len(result.unexpected_keys)} unexpected keys "
            f"(e.g. missing={result.missing_keys[:3]}, "
            f"unexpected={result.unexpected_keys[:3]})"
        )
    model.eval()

    print("Loading KMeans + Scaler...")
    # joblib.load unpickles arbitrary objects: only ever point it at trusted,
    # locally shipped files, never at user-supplied paths.
    kmeans = joblib.load(os.path.join(models_dir, "risk_model.pkl"))
    risk_scaler = joblib.load(os.path.join(models_dir, "risk_scaler.pkl"))

    # Publish all three together so readers never see a partial set.
    cnn_model, kmeans_model, scaler = model, kmeans, risk_scaler
    print("ALL MODELS READY 🚀")
| # ---------------- AUTO LOAD (IMPORTANT FIX) ---------------- # | |