Commit 365d5ef
Parent(s): 132a7e5

upload

Files changed:
- config.json +1 -1
- run_pretrain_no_trainer.py +12 -8
config.json
CHANGED
@@ -47,7 +47,7 @@
   "feat_proj_dropout": 0.0,
   "feat_quantizer_dropout": 0.0,
   "final_dropout": 0.0,
-  "gradient_checkpointing":
+  "gradient_checkpointing": false,
   "hidden_act": "gelu",
   "hidden_dropout": 0.0,
   "hidden_size": 1024,
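The commit pins `gradient_checkpointing` to `false` in the model config (the old value was lost in extraction above). As a side note, newer `transformers` releases deprecate this config flag in favor of a model method; a minimal sketch, with config values taken from the file above and the rest illustrative:

    # Minimal sketch: toggling gradient checkpointing at runtime rather than in config.json.
    # Assumes a recent transformers version, where the config flag is deprecated.
    from transformers import Wav2Vec2Config, Wav2Vec2ForPreTraining

    config = Wav2Vec2Config(hidden_size=1024, hidden_dropout=0.0, final_dropout=0.0)
    model = Wav2Vec2ForPreTraining(config)

    model.gradient_checkpointing_enable()   # recompute activations in backward to save memory
    model.gradient_checkpointing_disable()  # equivalent to "gradient_checkpointing": false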
run_pretrain_no_trainer.py
CHANGED
@@ -1,20 +1,20 @@
 #!/usr/bin/env python3
-import argparse
-import logging
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
 import os
-
 import torch
 import math
 import datasets
+import wandb
+import argparse
+import logging
+import librosa
+import transformers
+
 from datasets import DatasetDict, load_dataset
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Union
 from accelerate import Accelerator, DeepSpeedPlugin
 from tqdm.auto import tqdm
 from torch.utils.data.dataloader import DataLoader
-
-import librosa
-import transformers
 from transformers import (
     MODEL_MAPPING,
     SchedulerType,
@@ -34,6 +34,9 @@ MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


+wandb.init(project="pretraining-wav2vec2")
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
     parser.add_argument(
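The `wandb.init` call above runs at module level, i.e. before `parse_args()`. A sketch of an alternative (an assumption, not what this commit does) that initializes the run inside `main()` so parsed hyperparameters land in the W&B config:

    # Hedged sketch: initialize wandb after argument parsing. "pretraining-wav2vec2"
    # is the project name from the diff; the config fields are hypothetical examples.
    # parse_args() is the script's own argument parser defined above.
    import wandb

    def main():
        args = parse_args()
        wandb.init(
            project="pretraining-wav2vec2",
            config={"learning_rate": args.learning_rate, "max_train_steps": args.max_train_steps},
        )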
@@ -505,6 +508,7 @@ def main():
         for k, v in logs.items():
             log_str += f"| {k}: {round(v.item(), 5)}"

+        wandb.log(logs)
         progress_bar.write(log_str)

         if completed_steps >= args.max_train_steps:
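`wandb.log(logs)` here receives tensor values (note the `.item()` call in the f-string above). A small variant, offered as an assumption rather than the commit's behavior, converts them to plain floats and pins the x-axis to the step counter:

    # Hedged sketch: log plain floats keyed by the optimizer-step counter so the
    # W&B charts align with `completed_steps` from the training loop.
    wandb.log({k: v.item() for k, v in logs.items()}, step=completed_steps)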