Spaces:
Sleeping
Sleeping
File size: 5,332 Bytes
e3bdc52 ea97e04 e3bdc52 ea97e04 e3bdc52 ea97e04 e3bdc52 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | import os
import torch
import librosa
from torch.utils.data import Dataset
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification, Trainer, TrainingArguments
from typing import Dict, List
# Define o modelo base usado pelo projeto
BASE_MODEL_NAME = "HyperMoon/wav2vec2-base-960h-finetuned-deepfake"
LOCAL_MODEL_DIR = "./local_finetuned_model"
def get_processor():
"""Retorna o extrator de características do modelo base (processador de áudio puro, sem tokenizador de texto)"""
return Wav2Vec2FeatureExtractor.from_pretrained(BASE_MODEL_NAME)
class DeepfakeAudioDataset(Dataset):
"""
Dataset Customizado do Pytorch para carregar áudios de Pastas.
Espera-se que o diretório base tenha duas subpastas: 'real' e 'fake'.
"""
def __init__(self, root_dir: str, processor: Wav2Vec2FeatureExtractor, max_length: int = 160000):
self.root_dir = root_dir
self.processor = processor
self.max_length = max_length
self.files: List[Dict] = []
self._load_metadata()
def _load_metadata(self):
real_dir = os.path.join(self.root_dir, 'real')
fake_dir = os.path.join(self.root_dir, 'fake')
if os.path.exists(real_dir):
for f in os.listdir(real_dir):
if f.lower().endswith(('.wav', '.mp3', '.flac')):
self.files.append({"path": os.path.join(real_dir, f), "label": 0})
if os.path.exists(fake_dir):
for f in os.listdir(fake_dir):
if f.lower().endswith(('.wav', '.mp3', '.flac')):
self.files.append({"path": os.path.join(fake_dir, f), "label": 1})
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
item = self.files[idx]
audio_path = item["path"]
label = item["label"]
# Load and resample audio to 16kHz
speech, _ = librosa.load(audio_path, sr=16000)
# Process audio to get input values
input_values = self.processor(
speech,
sampling_rate=16000,
return_tensors="pt",
padding="max_length",
max_length=self.max_length,
truncation=True
).input_values[0]
return {
"input_values": input_values,
"labels": torch.tensor(label, dtype=torch.long)
}
def start_finetuning(dataset_dir: str):
"""
Inicia o treinamento congelando as camadas base para evitar OOM e focar apenas na cabeça de classificação.
"""
processor = get_processor()
# Prepara os datasets com split de 80/20 para avaliação real
full_dataset = DeepfakeAudioDataset(dataset_dir, processor)
if len(full_dataset) < 10:
print("⚠️ Dataset muito pequeno. Usando todo o conjunto para treino e eval.")
train_dataset = full_dataset
eval_dataset = full_dataset
else:
train_size = int(0.8 * len(full_dataset))
eval_size = len(full_dataset) - train_size
train_dataset, eval_dataset = torch.utils.data.random_split(full_dataset, [train_size, eval_size])
print(f"📊 Dataset dividido: {train_size} para treino, {eval_size} para avaliação.")
if len(train_dataset) == 0:
raise ValueError("Nenhum áudio encontrado no dataset.")
# Mapeamento explícito para evitar confusão de labels (0=Real, 1=Fraude)
id2label = {0: "AUTHENTIC", 1: "FAKE"}
label2id = {"AUTHENTIC": 0, "FAKE": 1}
# Carrega modelo e congela base
model = Wav2Vec2ForSequenceClassification.from_pretrained(
BASE_MODEL_NAME,
num_labels=2,
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True
)
# Freeze feature extractor e a base do transformer para poupar memória e tempo (Adaptação para hardwares fracos)
if hasattr(model, 'freeze_feature_encoder'):
model.freeze_feature_encoder()
elif hasattr(model, 'freeze_feature_extractor'):
model.freeze_feature_extractor()
if hasattr(model, 'wav2vec2'):
for param in model.wav2vec2.parameters():
param.requires_grad = False
# Training args voltados para hardware modesto
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=5,
per_device_train_batch_size=2, # Batch muito pequeno para não estourar memória
gradient_accumulation_steps=4, # Acumula para dar efeito de batch=8
learning_rate=2e-5,
save_strategy="epoch",
logging_dir="./logs",
logging_steps=1,
remove_unused_columns=False,
report_to="none", # Evita erros de conexão com serviços externos de log
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset, # Agora usando o split real de 20%
)
trainer.train()
# Salva o modelo afinado
model.save_pretrained(LOCAL_MODEL_DIR)
processor.save_pretrained(LOCAL_MODEL_DIR)
return True
if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
start_finetuning(sys.argv[1])
|