Spaces:
Paused
| import torch | |
| import torchaudio | |
| import json | |
| from transformers import Trainer, TrainingArguments, Wav2Vec2ForCTC, Wav2Vec2Processor | |
| # 1. Load Audio Data from JSON File | |
def load_audio_from_json(json_file):
    """Load local audio files listed in a JSON manifest.

    The manifest is expected to look like
    ``{"audio_files": [{"path": ...} or {"url": ...}, ...]}``.
    Entries carrying a ``url`` are skipped (remote download is not
    implemented); entries with a ``path`` are loaded from disk.

    Args:
        json_file: Path to the JSON manifest file.

    Returns:
        List of ``(waveform, sample_rate)`` tuples as produced by
        ``torchaudio.load`` (waveform shape is ``(channels, time)``).
    """
    with open(json_file, 'r') as f:
        data = json.load(f)
    audio_samples = []
    # .get() with a default tolerates a manifest that has no
    # 'audio_files' key instead of raising KeyError.
    for item in data.get('audio_files', []):
        if item.get('url'):
            # Remote entries are deliberately skipped; add download
            # handling here if URL support is ever needed.
            continue
        audio, sr = torchaudio.load(item['path'])
        audio_samples.append((audio, sr))
    return audio_samples
# Load local clips listed in the manifest (URL-only entries are skipped
# by load_audio_from_json).
audio_samples = load_audio_from_json('audio_data.json')
# 2. Load Pre-trained Model and Processor
# NOTE(review): both calls download from the Hugging Face hub on first
# run — requires network access or a populated local cache.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
| # 3. Preprocess Data | |
def preprocess_audio(audio_sample):
    """Convert one loaded clip into model-ready tensors.

    Downmixes to mono, resamples to the processor's expected rate, and
    runs the Wav2Vec2 feature extractor.

    Args:
        audio_sample: ``(waveform, sample_rate)`` tuple from
            ``torchaudio.load`` (waveform shape ``(channels, time)``).

    Returns:
        Tuple of ``(input_values, attention_mask)`` 1-D tensors.
    """
    audio, sr = audio_sample
    # torchaudio returns (channels, time); downmix to a 1-D mono signal
    # so the feature extractor sees a single sequence.
    if audio.dim() > 1:
        audio = audio.mean(dim=0)
    # wav2vec2-base-960h was trained on 16 kHz audio; resample anything
    # else instead of feeding a mismatched rate to the processor.
    target_sr = processor.feature_extractor.sampling_rate
    if sr != target_sr:
        audio = torchaudio.functional.resample(audio, sr, target_sr)
        sr = target_sr
    inputs = processor(audio.numpy(), sampling_rate=sr, return_tensors="pt", padding=True)
    # The base-960h feature extractor has return_attention_mask=False,
    # so 'attention_mask' may be absent — fall back to an all-ones mask
    # rather than raising.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(inputs.input_values, dtype=torch.long)
    return inputs.input_values[0], attention_mask[0]
| dataset = [(preprocess_audio(sample)) for sample in audio_samples] | |
# 4. Training Arguments
training_args = TrainingArguments(
    output_dir="./rvc_checkpoints",
    # No eval_dataset is passed to the Trainer below, so per-epoch
    # evaluation would raise ("evaluation requires an eval_dataset");
    # disable it until an eval split exists.
    evaluation_strategy="no",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Effective train batch size = 2 * 2 = 4.
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
    # Mixed precision only when a CUDA device is actually present.
    fp16=torch.cuda.is_available(),
)
# 5. Trainer Setup
# NOTE(review): no data_collator or eval_dataset is supplied, and the
# dataset items are plain tuples without 'labels' — Trainer's default
# collation likely fails or trains without a CTC loss; verify with a
# short dry run before a full training job.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
# 6. Train Model
# Runs the full training loop; checkpoints land in ./rvc_checkpoints
# per the save_strategy above.
trainer.train()
# 7. Save Model
# Save weights and processor together so the directory is loadable
# later via from_pretrained("./rvc_trained_model").
model.save_pretrained("./rvc_trained_model")
processor.save_pretrained("./rvc_trained_model")
print("Training Completed!")