# NOTE: "Spaces: Sleeping" header residue from the Hugging Face Spaces page
# this file was extracted from; the module source begins below.
| import torch | |
| import traceback | |
| from transformers import AutoModelForAudioClassification, AutoFeatureExtractor | |
| from speechbrain.inference.VAD import VAD | |
| import os | |
class ModelWrapper:
    """Wraps an audio anti-deepfake classifier plus a SpeechBrain VAD.

    Config keys (all optional):
        name: HF model id (default: "nii-yamagishilab/mms-300m-anti-deepfake").
        device: torch device string (default: "cpu").
        feature_extractor: HF id for the feature extractor (defaults to `name`).
        vad.repo: SpeechBrain VAD source repo.
        vad.save_path: local cache dir for the VAD checkpoint.

    On load failure, `model`/`feature_extractor`/`vad` are left as None;
    `predict` raises RuntimeError if the classifier is unavailable.
    """

    def __init__(self, config: dict):
        self.config = config
        self.model_name = config.get("name", "nii-yamagishilab/mms-300m-anti-deepfake")
        self.device = config.get("device", "cpu")
        self.model = None
        self.feature_extractor = None
        self.vad = None
        # Log library versions for debugging
        try:
            import transformers
            import safetensors
            print(f"Library versions - transformers: {transformers.__version__}, safetensors: {safetensors.__version__}")
        except Exception as e:
            print(f"Warning: Could not log library versions: {e}")
        self.load_model()
        self.load_vad()

    def load_model(self):
        """Load the classifier and feature extractor; on any failure, reset
        both to None (callers check for None / predict raises)."""
        try:
            print(f"Loading Deepfake Detection model {self.model_name} on {self.device}...")
            model = AutoModelForAudioClassification.from_pretrained(
                self.model_name,
                trust_remote_code=True
            ).to(self.device)
            fe_name = self.config.get("feature_extractor", self.model_name)
            feature_extractor = AutoFeatureExtractor.from_pretrained(fe_name)
            # Only set if both loaded successfully
            self.model = model
            self.feature_extractor = feature_extractor
            self.model.eval()
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            print(f"Model name: {self.model_name}, Device: {self.device}")
            traceback.print_exc()
            self.model = None
            self.feature_extractor = None

    def load_vad(self):
        """Load the SpeechBrain VAD; tolerate failure by leaving self.vad None
        (downstream can then process the whole audio instead of segments)."""
        try:
            vad_repo = self.config.get("vad", {}).get("repo", "speechbrain/vad-crdnn-libriparty")
            print(f"Loading SpeechBrain VAD from {vad_repo}...")
            # VAD loads internal models, ensure we catch errors here too
            self.vad = VAD.from_hparams(
                source=vad_repo,
                savedir=self.config.get("vad", {}).get("save_path", "model_checkpoints")
            )
            print("SpeechBrain VAD loaded.")
        except Exception as e:
            print(f"Error loading VAD: {e}")
            traceback.print_exc()
            # We can tolerate VAD failure slightly by processing whole audio, or fail hard.
            # For now, let's keep it robust.
            self.vad = None

    def predict(self, audio: torch.Tensor, sr: int) -> float:
        """
        Predict probability of AI generation.
        Returns float (0.0 to 1.0), where 1.0 is AI.

        Raises:
            RuntimeError: if the classifier failed to load.
        """
        if self.model is None or self.feature_extractor is None:
            raise RuntimeError("Model not loaded")
        with torch.no_grad():
            # detach().cpu() so this also works for CUDA and grad-tracking
            # tensors; a bare .numpy() raises in both of those cases.
            inputs = self.feature_extractor(
                audio.detach().cpu().numpy(),
                sampling_rate=sr,
                return_tensors="pt"
            ).to(self.device)
            # Inference
            outputs = self.model(**inputs)
            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # Label mapping:
            # id2label usually {0: 'bonafide', 1: 'spoof'} OR {0: 'real', 1: 'fake'}
            # For mms-300m-anti-deepfake: 0 is 'bonafide' (human), 1 is 'spoof' (AI)
            # (Logs from repro script said: Labels: {0: 'LABEL_0', 1: 'LABEL_1'})
            # NOTE(review): LABEL_1 is assumed to be the positive (spoof) class —
            # verify against the model config's id2label.
            score = probs[0][1]
            ai_prob = score.item()
            # Safety check: handle NaN/Inf (can occur if model weights are
            # improperly loaded). Check the tensor element directly instead of
            # round-tripping the float back through torch.tensor().
            if not torch.isfinite(score):
                print(f"WARNING: Model returned non-finite value: {ai_prob}. Returning 0.5 as fallback.")
                return 0.5
            return ai_prob