# Vclone-221 / app.py — Wav2Vec2 fine-tuning script
# (uploaded by Fred808, "Create app.py", commit 9331c57 verified)
import torch
import torchaudio
import json
from transformers import Trainer, TrainingArguments, Wav2Vec2ForCTC, Wav2Vec2Processor
# 1. Load Audio Data from JSON File
def load_audio_from_json(json_file):
    """Load local audio files listed in a JSON manifest.

    The manifest is expected to look like::

        {"audio_files": [{"path": "a.wav"}, {"url": "http://..."}]}

    Entries with a ``url`` are skipped (remote download is not implemented
    here), as are entries that carry neither ``url`` nor ``path`` — the
    original code raised ``KeyError`` on such entries.

    Args:
        json_file: Path to the JSON manifest file.

    Returns:
        List of ``(waveform, sample_rate)`` tuples as produced by
        ``torchaudio.load``.
    """
    with open(json_file, 'r') as f:
        data = json.load(f)

    audio_samples = []
    for item in data['audio_files']:
        if item.get('url'):
            # Downloading from URL (requires additional handling if desired)
            continue
        if not item.get('path'):
            # Neither a URL nor a local path: nothing to load for this entry.
            continue
        audio, sr = torchaudio.load(item['path'])
        audio_samples.append((audio, sr))
    return audio_samples
# Build the in-memory sample list from the manifest expected next to this script.
audio_samples = load_audio_from_json('audio_data.json')
# 2. Load Pre-trained Model and Processor
# Both are fetched from the Hugging Face hub; the processor bundles the
# feature extractor and tokenizer matching this checkpoint.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# 3. Preprocess Data
def preprocess_audio(audio_sample):
    """Convert one ``(waveform, sample_rate)`` pair into model features.

    Returns a dict — the item format ``Trainer`` expects so features can be
    collated and passed to the model as keyword arguments (the original
    returned a bare tuple, which the Trainer cannot map to model inputs).

    Args:
        audio_sample: ``(waveform_tensor, sample_rate)`` as produced by
            ``load_audio_from_json``.

    Returns:
        dict with ``input_values`` and, when the processor emits one,
        ``attention_mask``.
    """
    audio, sr = audio_sample
    inputs = processor(audio.numpy(), sampling_rate=sr, return_tensors="pt", padding=True)
    features = {"input_values": inputs.input_values[0]}
    # The wav2vec2-base feature extractor may be configured with
    # return_attention_mask=False; the original unconditional
    # `inputs.attention_mask` access raised AttributeError in that case.
    if "attention_mask" in inputs:
        features["attention_mask"] = inputs.attention_mask[0]
    return features

# NOTE(review): Wav2Vec2ForCTC needs `labels` (tokenized transcripts) to
# compute a loss, and the manifest carries none — Trainer.train() will fail
# without them. Add a `labels` entry to each feature dict; confirm intent.
dataset = [preprocess_audio(sample) for sample in audio_samples]
# 4. Training Arguments
training_args = TrainingArguments(
    output_dir="./rvc_checkpoints",
    # No eval_dataset is passed to the Trainer below, so per-epoch
    # evaluation ("epoch" in the original) would crash at the end of the
    # first epoch; evaluation is disabled instead.
    evaluation_strategy="no",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Effective batch size = 2 * 2 = 4 per device.
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
    # Mixed precision only when a CUDA device is present.
    fp16=torch.cuda.is_available(),
)
# 5. Trainer Setup
trainer = Trainer(
    model=model,
    args=training_args,
    # NOTE(review): dataset items are variable-length and no data_collator
    # is supplied; the default collator likely cannot batch them — a padding
    # collator (e.g. one built on processor.pad) is probably needed. Confirm.
    train_dataset=dataset,
)
# 6. Train Model
# Runs the full fine-tuning loop configured by training_args above.
trainer.train()
# 7. Save Model
# Persist both the fine-tuned weights and the processor so the pair can be
# reloaded together with from_pretrained().
save_dir = "./rvc_trained_model"
model.save_pretrained(save_dir)
processor.save_pretrained(save_dir)
print("Training Completed!")