import warnings

import joblib
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torchvision import models

# ---------------- GLOBAL MODELS ---------------- #
# Populated by load_all_models(); each stays None until that function runs.
cnn_model = None
kmeans_model = None
scaler = None


class DenseNet121_CheXpert(torch.nn.Module):
    """DenseNet-121 backbone whose classifier is replaced by a linear head.

    The backbone is constructed without pretrained weights (``weights=None``);
    weights are expected to be supplied afterwards via ``load_state_dict``.

    Args:
        num_labels: Number of outputs produced by the final linear layer.
    """

    def __init__(self, num_labels=14):
        super().__init__()
        self.densenet = models.densenet121(weights=None)
        num_features = self.densenet.classifier.in_features
        # Replace torchvision's default classifier with one sized to num_labels.
        self.densenet.classifier = torch.nn.Linear(num_features, num_labels)

    def forward(self, x):
        """Return the linear-head outputs (no activation applied) for ``x``."""
        return self.densenet(x)


# ---------------- LOAD FUNCTION ---------------- #
def load_all_models(
    *,
    repo_id="itsomk/chexpert-densenet121",
    weights_filename="pytorch_model.safetensors",
    kmeans_path="models/risk_model.pkl",
    scaler_path="models/risk_scaler.pkl",
):
    """Load the CNN and the KMeans + scaler pair into the module globals.

    Downloads the safetensors checkpoint from the Hugging Face Hub, loads it
    into a fresh ``DenseNet121_CheXpert`` in eval mode, then loads the KMeans
    model and scaler with ``joblib``. Results are stored in the module-level
    ``cnn_model``, ``kmeans_model`` and ``scaler``.

    Keyword Args:
        repo_id: Hugging Face Hub repository holding the CNN checkpoint.
        weights_filename: Checkpoint file name inside ``repo_id``.
        kmeans_path: Path to the pickled KMeans model. Relative paths are
            resolved against the current working directory.
        scaler_path: Path to the pickled scaler (same resolution rules).

    Warns:
        UserWarning: If the checkpoint's keys do not exactly match the model's
            parameters/buffers (some layers may remain randomly initialised).
    """
    global cnn_model, kmeans_model, scaler

    print("Downloading DenseNet...")
    local_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)

    print("Loading CNN...")
    state = load_file(local_path)
    model = DenseNet121_CheXpert()
    # strict=False tolerates key mismatches, but any missing key means that
    # layer keeps its random initialisation. Report it instead of ignoring it.
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            "CNN checkpoint does not match DenseNet121_CheXpert: "
            f"{len(result.missing_keys)} missing key(s) "
            f"(e.g. {result.missing_keys[:3]}), "
            f"{len(result.unexpected_keys)} unexpected key(s) "
            f"(e.g. {result.unexpected_keys[:3]}).",
            stacklevel=2,
        )
    model.eval()
    # Publish only a fully loaded model to the global.
    cnn_model = model

    print("Loading KMeans + Scaler...")
    # joblib.load unpickles arbitrary objects; only point these at trusted files.
    kmeans_model = joblib.load(kmeans_path)
    scaler = joblib.load(scaler_path)

    print("ALL MODELS READY 🚀")


# ---------------- AUTO LOAD (IMPORTANT FIX) ---------------- #
# NOTE(review): no statement follows this heading in the reviewed source.
# Confirm whether load_all_models() is meant to run at import time here, or
# whether callers are expected to invoke it explicitly.