File size: 1,800 Bytes
9331c57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import json
import warnings

import torch
import torchaudio
from transformers import Trainer, TrainingArguments, Wav2Vec2ForCTC, Wav2Vec2Processor

# 1. Load Audio Data from JSON File
def load_audio_from_json(json_file):
    """Load every locally stored audio file listed in a JSON manifest.

    The manifest must be a JSON object with an ``"audio_files"`` list. Each
    entry with a truthy ``"url"`` is skipped (with a warning); every other
    entry is loaded from its ``"path"`` with ``torchaudio.load``.

    Args:
        json_file: Path to the JSON manifest.

    Returns:
        A list of ``(waveform, sample_rate)`` tuples in manifest order, one
        per loaded (non-URL) entry.

    Raises:
        KeyError: If the manifest has no ``"audio_files"`` key, or a non-URL
            entry has no ``"path"``.
    """
    # JSON text is UTF-8 by specification; do not depend on the locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    audio_samples = []
    for item in data['audio_files']:
        if item.get('url'):
            # Remote files would need to be downloaded first, which is not
            # implemented; make the skip visible instead of dropping it silently.
            warnings.warn(
                f"Skipping remote audio entry (URL download not supported): {item['url']}",
                stacklevel=2,
            )
            continue
        audio, sr = torchaudio.load(item['path'])
        audio_samples.append((audio, sr))
    return audio_samples

# Every locally stored clip listed in the manifest, as (waveform, sample_rate) pairs.
audio_samples = load_audio_from_json('audio_data.json')

# 2. Load Pre-trained Model and Processor
# Processor and model are created from the same Hub checkpoint id, so the
# preprocessing settings and the weights come from one place.
BASE_CHECKPOINT = "facebook/wav2vec2-base-960h"
processor, model = (
    Wav2Vec2Processor.from_pretrained(BASE_CHECKPOINT),
    Wav2Vec2ForCTC.from_pretrained(BASE_CHECKPOINT),
)

# 3. Preprocess Data
def preprocess_audio(audio_sample):
    """Turn one ``(waveform, sample_rate)`` pair into model-ready tensors.

    The waveform is down-mixed to mono and resampled to the sampling rate
    configured on the module-level ``processor``'s feature extractor, then
    normalised by that processor.

    Args:
        audio_sample: ``(waveform, sample_rate)`` as returned by
            ``torchaudio.load`` (a channels-first float tensor and an int).

    Returns:
        ``(input_values, attention_mask)`` for this single, unpadded clip.
    """
    audio, sr = audio_sample
    target_sr = processor.feature_extractor.sampling_rate

    if audio.dim() > 1:
        # A 2-D array would be treated by the processor as a batch of separate
        # examples (one per channel); average the channels into one signal.
        audio = audio.mean(dim=0)
    if sr != target_sr:
        # The feature extractor raises ValueError for any other rate.
        audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=target_sr)

    # Checkpoints such as wav2vec2-base-960h are configured with
    # return_attention_mask=False, in which case `inputs.attention_mask` does
    # not exist; request it explicitly so the return value is always a pair.
    inputs = processor(
        audio.numpy(),
        sampling_rate=target_sr,
        return_tensors="pt",
        return_attention_mask=True,
    )
    return inputs.input_values[0], inputs.attention_mask[0]

# One example dict per clip: Trainer's collators operate on mappings, not tuples.
dataset = [
    {"input_values": input_values, "attention_mask": attention_mask}
    for input_values, attention_mask in map(preprocess_audio, audio_samples)
]
if not dataset:
    raise RuntimeError(
        "No locally stored audio files were loaded from 'audio_data.json'; nothing to train on."
    )
# NOTE(review): Wav2Vec2ForCTC only returns a loss when each example carries
# "labels" (tokenised transcripts, e.g. processor.tokenizer(text).input_ids).
# Nothing here provides them, so Trainer will stop with "The model did not
# return a loss" until transcripts are added to the dataset.

# 4. Training Arguments
# No eval_dataset is built, so evaluation stays at its default ("no");
# requesting per-epoch evaluation without one makes Trainer raise (and the
# `evaluation_strategy` keyword is deprecated/renamed to `eval_strategy`).
training_args = TrainingArguments(
    output_dir="./rvc_checkpoints",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
    fp16=torch.cuda.is_available(),
)


def _pad_collate(features):
    """Pad a list of example dicts into one batch of equal-length tensors.

    ``input_values`` are right-padded with the feature extractor's padding
    value. ``attention_mask`` is forwarded only when the feature extractor is
    configured to return one (Hugging Face advises not passing it to models
    such as wav2vec2-base). ``labels``, if present, are padded with -100,
    which the CTC loss ignores.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    feature_extractor = processor.feature_extractor
    batch = {
        "input_values": pad(
            [f["input_values"] for f in features],
            batch_first=True,
            padding_value=feature_extractor.padding_value,
        ),
    }
    if feature_extractor.return_attention_mask:
        batch["attention_mask"] = pad(
            [f["attention_mask"] for f in features],
            batch_first=True,
            padding_value=0,
        )
    if "labels" in features[0]:
        batch["labels"] = pad(
            [torch.as_tensor(f["labels"]) for f in features],
            batch_first=True,
            padding_value=-100,
        )
    return batch


# 5. Trainer Setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=_pad_collate,  # clips differ in length, so they must be padded per batch
)

# 6. Train Model
trainer.train()

# 7. Save Model
model.save_pretrained("./rvc_trained_model")
processor.save_pretrained("./rvc_trained_model")

print("Training Completed!")