Fred808 committed on
Commit
9331c57
·
verified ·
1 Parent(s): 0a9ae38

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import json
4
+ from transformers import Trainer, TrainingArguments, Wav2Vec2ForCTC, Wav2Vec2Processor
5
+
6
+ # 1. Load Audio Data from JSON File
7
def load_audio_from_json(json_file):
    """Load the locally stored audio clips listed in a JSON manifest.

    The manifest is a JSON object whose ``audio_files`` key holds a list of
    entries. Entries with a truthy ``url`` are skipped (remote download is
    not implemented here); every other entry must carry a ``path`` that is
    handed to ``torchaudio.load``.

    Args:
        json_file: Path to the JSON manifest.

    Returns:
        A list of ``(audio, sr)`` tuples exactly as returned by
        ``torchaudio.load``, one per loaded entry, in manifest order.

    Raises:
        KeyError: If ``audio_files`` is missing, or an entry has neither a
            ``url`` nor a ``path``.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    audio_samples = []
    for index, item in enumerate(data['audio_files']):
        if item.get('url'):
            # Downloading from URL (requires additional handling if desired)
            continue
        if 'path' not in item:
            # Fail with the offending position instead of a bare KeyError.
            raise KeyError(
                f"audio_files[{index}] has neither a 'url' nor a 'path'"
            )
        audio, sr = torchaudio.load(item['path'])
        audio_samples.append((audio, sr))
    return audio_samples
18
+
19
# Read every local clip listed in the manifest.
audio_samples = load_audio_from_json('audio_data.json')

# 2. Load Pre-trained Model and Processor
# The processor and the model are fetched from the same checkpoint id.
_PRETRAINED_ID = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(_PRETRAINED_ID)
model = Wav2Vec2ForCTC.from_pretrained(_PRETRAINED_ID)
24
+
25
+ # 3. Preprocess Data
26
def preprocess_audio(audio_sample):
    """Convert one loaded clip into inputs for the Wav2Vec2 model.

    Args:
        audio_sample: ``(audio, sr)`` tuple as produced by
            ``load_audio_from_json``. ``audio`` is assumed to be the
            ``(channels, frames)`` tensor that ``torchaudio.load`` returns
            by default; a 1-D tensor is also accepted.

    Returns:
        ``(input_values, attention_mask)`` for the single clip.
    """
    audio, sr = audio_sample

    # The feature extractor treats a 2-D array as a batch of separate clips,
    # so a multi-channel waveform would lose every channel but the first
    # when element [0] is taken below. Downmix to mono instead.
    if audio.dim() > 1:
        audio = audio.mean(dim=0)

    # The feature extractor raises if the declared rate differs from the one
    # it was configured with, so resample rather than pass the file's
    # native rate through.
    target_sr = processor.feature_extractor.sampling_rate
    if sr != target_sr:
        audio = torchaudio.functional.resample(audio, sr, target_sr)

    inputs = processor(
        audio.numpy(),
        sampling_rate=target_sr,
        return_tensors="pt",
        padding=True,
        # Some checkpoints do not return a mask by default; ask for it
        # explicitly because it is read unconditionally below.
        return_attention_mask=True,
    )
    return inputs.input_values[0], inputs.attention_mask[0]
30
+
31
# One (input_values, attention_mask) pair per loaded clip.
# NOTE(review): the Trainer's default collator appears to expect dict-like
# examples, and Wav2Vec2ForCTC only produces a loss when labels are
# supplied -- confirm how transcripts are meant to reach the Trainer.
dataset = [preprocess_audio(sample) for sample in audio_samples]

# 4. Training Arguments
training_args = TrainingArguments(
    output_dir="./rvc_checkpoints",
    # No evaluation strategy is set: the Trainer below is given no
    # eval_dataset, so requesting per-epoch evaluation would make the run
    # fail. The default leaves evaluation off.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
    # Mixed precision only when a CUDA device is present.
    fp16=torch.cuda.is_available(),
)
47
+
48
# 5. Trainer Setup
trainer = Trainer(model=model, args=training_args, train_dataset=dataset)

# 6. Train Model
trainer.train()

# 7. Save Model
# The model is written first, then the processor, into the same directory.
for _component in (model, processor):
    _component.save_pretrained("./rvc_trained_model")

print("Training Completed!")