patrickvonplaten commited on
Commit
365d5ef
·
1 Parent(s): 132a7e5
Files changed (2) hide show
  1. config.json +1 -1
  2. run_pretrain_no_trainer.py +12 -8
config.json CHANGED
@@ -47,7 +47,7 @@
47
  "feat_proj_dropout": 0.0,
48
  "feat_quantizer_dropout": 0.0,
49
  "final_dropout": 0.0,
50
- "gradient_checkpointing": true,
51
  "hidden_act": "gelu",
52
  "hidden_dropout": 0.0,
53
  "hidden_size": 1024,
 
47
  "feat_proj_dropout": 0.0,
48
  "feat_quantizer_dropout": 0.0,
49
  "final_dropout": 0.0,
50
+ "gradient_checkpointing": false,
51
  "hidden_act": "gelu",
52
  "hidden_dropout": 0.0,
53
  "hidden_size": 1024,
run_pretrain_no_trainer.py CHANGED
@@ -1,20 +1,20 @@
1
  #!/usr/bin/env python3
2
- import argparse
3
- import logging
4
- from dataclasses import dataclass
5
- from typing import Dict, List, Optional, Union
6
  import os
7
-
8
  import torch
9
  import math
10
  import datasets
 
 
 
 
 
 
11
  from datasets import DatasetDict, load_dataset
 
 
12
  from accelerate import Accelerator, DeepSpeedPlugin
13
  from tqdm.auto import tqdm
14
  from torch.utils.data.dataloader import DataLoader
15
-
16
- import librosa
17
- import transformers
18
  from transformers import (
19
  MODEL_MAPPING,
20
  SchedulerType,
@@ -34,6 +34,9 @@ MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
34
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
35
 
36
 
 
 
 
37
  def parse_args():
38
  parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
39
  parser.add_argument(
@@ -505,6 +508,7 @@ def main():
505
  for k, v in logs.items():
506
  log_str += f"| {k}: {round(v.item(), 5)}"
507
 
 
508
  progress_bar.write(log_str)
509
 
510
  if completed_steps >= args.max_train_steps:
 
1
  #!/usr/bin/env python3
 
 
 
 
2
  import os
 
3
  import torch
4
  import math
5
  import datasets
6
+ import wandb
7
+ import argparse
8
+ import logging
9
+ import librosa
10
+ import transformers
11
+
12
  from datasets import DatasetDict, load_dataset
13
+ from dataclasses import dataclass
14
+ from typing import Dict, List, Optional, Union
15
  from accelerate import Accelerator, DeepSpeedPlugin
16
  from tqdm.auto import tqdm
17
  from torch.utils.data.dataloader import DataLoader
 
 
 
18
  from transformers import (
19
  MODEL_MAPPING,
20
  SchedulerType,
 
34
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
35
 
36
 
37
+ wandb.init(project="pretraining-wav2vec2")
38
+
39
+
40
  def parse_args():
41
  parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
42
  parser.add_argument(
 
508
  for k, v in logs.items():
509
  log_str += f"| {k}: {round(v.item(), 5)}"
510
 
511
+ wandb.log(logs)
512
  progress_bar.write(log_str)
513
 
514
  if completed_steps >= args.max_train_steps: