Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import librosa | |
| from torch.utils.data import Dataset | |
| from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification, Trainer, TrainingArguments | |
| from typing import Dict, List | |
| # Define o modelo base usado pelo projeto | |
| BASE_MODEL_NAME = "HyperMoon/wav2vec2-base-960h-finetuned-deepfake" | |
| LOCAL_MODEL_DIR = "./local_finetuned_model" | |
| def get_processor(): | |
| """Retorna o extrator de características do modelo base (processador de áudio puro, sem tokenizador de texto)""" | |
| return Wav2Vec2FeatureExtractor.from_pretrained(BASE_MODEL_NAME) | |
| class DeepfakeAudioDataset(Dataset): | |
| """ | |
| Dataset Customizado do Pytorch para carregar áudios de Pastas. | |
| Espera-se que o diretório base tenha duas subpastas: 'real' e 'fake'. | |
| """ | |
| def __init__(self, root_dir: str, processor: Wav2Vec2FeatureExtractor, max_length: int = 160000): | |
| self.root_dir = root_dir | |
| self.processor = processor | |
| self.max_length = max_length | |
| self.files: List[Dict] = [] | |
| self._load_metadata() | |
| def _load_metadata(self): | |
| real_dir = os.path.join(self.root_dir, 'real') | |
| fake_dir = os.path.join(self.root_dir, 'fake') | |
| if os.path.exists(real_dir): | |
| for f in os.listdir(real_dir): | |
| if f.lower().endswith(('.wav', '.mp3', '.flac')): | |
| self.files.append({"path": os.path.join(real_dir, f), "label": 0}) | |
| if os.path.exists(fake_dir): | |
| for f in os.listdir(fake_dir): | |
| if f.lower().endswith(('.wav', '.mp3', '.flac')): | |
| self.files.append({"path": os.path.join(fake_dir, f), "label": 1}) | |
| def __len__(self): | |
| return len(self.files) | |
| def __getitem__(self, idx): | |
| item = self.files[idx] | |
| audio_path = item["path"] | |
| label = item["label"] | |
| # Load and resample audio to 16kHz | |
| speech, _ = librosa.load(audio_path, sr=16000) | |
| # Process audio to get input values | |
| input_values = self.processor( | |
| speech, | |
| sampling_rate=16000, | |
| return_tensors="pt", | |
| padding="max_length", | |
| max_length=self.max_length, | |
| truncation=True | |
| ).input_values[0] | |
| return { | |
| "input_values": input_values, | |
| "labels": torch.tensor(label, dtype=torch.long) | |
| } | |
| def start_finetuning(dataset_dir: str): | |
| """ | |
| Inicia o treinamento congelando as camadas base para evitar OOM e focar apenas na cabeça de classificação. | |
| """ | |
| processor = get_processor() | |
| # Prepara os datasets com split de 80/20 para avaliação real | |
| full_dataset = DeepfakeAudioDataset(dataset_dir, processor) | |
| if len(full_dataset) < 10: | |
| print("⚠️ Dataset muito pequeno. Usando todo o conjunto para treino e eval.") | |
| train_dataset = full_dataset | |
| eval_dataset = full_dataset | |
| else: | |
| train_size = int(0.8 * len(full_dataset)) | |
| eval_size = len(full_dataset) - train_size | |
| train_dataset, eval_dataset = torch.utils.data.random_split(full_dataset, [train_size, eval_size]) | |
| print(f"📊 Dataset dividido: {train_size} para treino, {eval_size} para avaliação.") | |
| if len(train_dataset) == 0: | |
| raise ValueError("Nenhum áudio encontrado no dataset.") | |
| # Mapeamento explícito para evitar confusão de labels (0=Real, 1=Fraude) | |
| id2label = {0: "AUTHENTIC", 1: "FAKE"} | |
| label2id = {"AUTHENTIC": 0, "FAKE": 1} | |
| # Carrega modelo e congela base | |
| model = Wav2Vec2ForSequenceClassification.from_pretrained( | |
| BASE_MODEL_NAME, | |
| num_labels=2, | |
| id2label=id2label, | |
| label2id=label2id, | |
| ignore_mismatched_sizes=True | |
| ) | |
| # Freeze feature extractor e a base do transformer para poupar memória e tempo (Adaptação para hardwares fracos) | |
| if hasattr(model, 'freeze_feature_encoder'): | |
| model.freeze_feature_encoder() | |
| elif hasattr(model, 'freeze_feature_extractor'): | |
| model.freeze_feature_extractor() | |
| if hasattr(model, 'wav2vec2'): | |
| for param in model.wav2vec2.parameters(): | |
| param.requires_grad = False | |
| # Training args voltados para hardware modesto | |
| training_args = TrainingArguments( | |
| output_dir="./results", | |
| num_train_epochs=5, | |
| per_device_train_batch_size=2, # Batch muito pequeno para não estourar memória | |
| gradient_accumulation_steps=4, # Acumula para dar efeito de batch=8 | |
| learning_rate=2e-5, | |
| save_strategy="epoch", | |
| logging_dir="./logs", | |
| logging_steps=1, | |
| remove_unused_columns=False, | |
| report_to="none", # Evita erros de conexão com serviços externos de log | |
| ) | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, # Agora usando o split real de 20% | |
| ) | |
| trainer.train() | |
| # Salva o modelo afinado | |
| model.save_pretrained(LOCAL_MODEL_DIR) | |
| processor.save_pretrained(LOCAL_MODEL_DIR) | |
| return True | |
| if __name__ == "__main__": | |
| import sys | |
| if len(sys.argv) > 1: | |
| start_finetuning(sys.argv[1]) | |