import os import torch import librosa from torch.utils.data import Dataset from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification, Trainer, TrainingArguments from typing import Dict, List # Define o modelo base usado pelo projeto BASE_MODEL_NAME = "HyperMoon/wav2vec2-base-960h-finetuned-deepfake" LOCAL_MODEL_DIR = "./local_finetuned_model" def get_processor(): """Retorna o extrator de características do modelo base (processador de áudio puro, sem tokenizador de texto)""" return Wav2Vec2FeatureExtractor.from_pretrained(BASE_MODEL_NAME) class DeepfakeAudioDataset(Dataset): """ Dataset Customizado do Pytorch para carregar áudios de Pastas. Espera-se que o diretório base tenha duas subpastas: 'real' e 'fake'. """ def __init__(self, root_dir: str, processor: Wav2Vec2FeatureExtractor, max_length: int = 160000): self.root_dir = root_dir self.processor = processor self.max_length = max_length self.files: List[Dict] = [] self._load_metadata() def _load_metadata(self): real_dir = os.path.join(self.root_dir, 'real') fake_dir = os.path.join(self.root_dir, 'fake') if os.path.exists(real_dir): for f in os.listdir(real_dir): if f.lower().endswith(('.wav', '.mp3', '.flac')): self.files.append({"path": os.path.join(real_dir, f), "label": 0}) if os.path.exists(fake_dir): for f in os.listdir(fake_dir): if f.lower().endswith(('.wav', '.mp3', '.flac')): self.files.append({"path": os.path.join(fake_dir, f), "label": 1}) def __len__(self): return len(self.files) def __getitem__(self, idx): item = self.files[idx] audio_path = item["path"] label = item["label"] # Load and resample audio to 16kHz speech, _ = librosa.load(audio_path, sr=16000) # Process audio to get input values input_values = self.processor( speech, sampling_rate=16000, return_tensors="pt", padding="max_length", max_length=self.max_length, truncation=True ).input_values[0] return { "input_values": input_values, "labels": torch.tensor(label, dtype=torch.long) } def start_finetuning(dataset_dir: str): """ Inicia o treinamento congelando as camadas base para evitar OOM e focar apenas na cabeça de classificação. """ processor = get_processor() # Prepara os datasets com split de 80/20 para avaliação real full_dataset = DeepfakeAudioDataset(dataset_dir, processor) if len(full_dataset) < 10: print("⚠️ Dataset muito pequeno. Usando todo o conjunto para treino e eval.") train_dataset = full_dataset eval_dataset = full_dataset else: train_size = int(0.8 * len(full_dataset)) eval_size = len(full_dataset) - train_size train_dataset, eval_dataset = torch.utils.data.random_split(full_dataset, [train_size, eval_size]) print(f"📊 Dataset dividido: {train_size} para treino, {eval_size} para avaliação.") if len(train_dataset) == 0: raise ValueError("Nenhum áudio encontrado no dataset.") # Mapeamento explícito para evitar confusão de labels (0=Real, 1=Fraude) id2label = {0: "AUTHENTIC", 1: "FAKE"} label2id = {"AUTHENTIC": 0, "FAKE": 1} # Carrega modelo e congela base model = Wav2Vec2ForSequenceClassification.from_pretrained( BASE_MODEL_NAME, num_labels=2, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True ) # Freeze feature extractor e a base do transformer para poupar memória e tempo (Adaptação para hardwares fracos) if hasattr(model, 'freeze_feature_encoder'): model.freeze_feature_encoder() elif hasattr(model, 'freeze_feature_extractor'): model.freeze_feature_extractor() if hasattr(model, 'wav2vec2'): for param in model.wav2vec2.parameters(): param.requires_grad = False # Training args voltados para hardware modesto training_args = TrainingArguments( output_dir="./results", num_train_epochs=5, per_device_train_batch_size=2, # Batch muito pequeno para não estourar memória gradient_accumulation_steps=4, # Acumula para dar efeito de batch=8 learning_rate=2e-5, save_strategy="epoch", logging_dir="./logs", logging_steps=1, remove_unused_columns=False, report_to="none", # Evita erros de conexão com serviços externos de log ) trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, # Agora usando o split real de 20% ) trainer.train() # Salva o modelo afinado model.save_pretrained(LOCAL_MODEL_DIR) processor.save_pretrained(LOCAL_MODEL_DIR) return True if __name__ == "__main__": import sys if len(sys.argv) > 1: start_finetuning(sys.argv[1])