ms180 commited on
Commit
d641d01
·
verified ·
1 Parent(s): 5dcded9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -1
app.py CHANGED
@@ -63,7 +63,11 @@ s2l = Speech2Language.from_pretrained(
63
  device=device,
64
  nbest=1,
65
  )
66
- s2t_ar = ARSpeech2Text.from_pretrained(
 
 
 
 
67
  model_tag=f"espnet/owsm_v4_medium_1B",
68
  device=device,
69
  beam_size=5,
@@ -74,6 +78,51 @@ s2t_ar = ARSpeech2Text.from_pretrained(
74
  task_sym="<asr>",
75
  predict_time=False,
76
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  s2t_ctc = CTCSpeech2Text.from_pretrained(
78
  model_tag=f"espnet/owsm_ctc_v4_1B",
79
  device=device,
 
63
  device=device,
64
  nbest=1,
65
  )
66
+
67
+ # Hacking to cange config
68
+ # 1. download files
69
+ try:
70
+ s2t_ar = ARSpeech2Text.from_pretrained(
71
  model_tag=f"espnet/owsm_v4_medium_1B",
72
  device=device,
73
  beam_size=5,
 
78
  task_sym="<asr>",
79
  predict_time=False,
80
  )
81
+ except Exception as e:
82
+ print("File downloaded")
83
+
84
+ # 2. Remove unrequired file
85
+ import yaml
86
+ from pathlib import Path
87
+ import espnet_model_zoo
88
+
89
+ d = "models--espnet--owsm_v4_medium_1B/snapshots/471418ddaf0b03c9ab1fd75f1f5d26fc3aea3aa9/exp/s2t_train_conv2d8_size1024_e18_d18_mel128_raw_bpe50000/config.yaml"
90
+ p = Path(espnet_model_zoo.__file__)
91
+ config_path = p.parent / d
92
+
93
+ def remove_key(obj, key="gradient_checkpoint_layers"):
94
+ if isinstance(obj, dict):
95
+ if key in obj:
96
+ del obj[key]
97
+ for k, v in list(obj.items()):
98
+ remove_key(v, key)
99
+ elif isinstance(obj, list):
100
+ for item in obj:
101
+ remove_key(item, key)
102
+
103
+ with open(config_path, "r") as f:
104
+ config = yaml.safe_load(f)
105
+
106
+ remove_key(config)
107
+
108
+ with open(config_path, "w") as f:
109
+ yaml.safe_dump(config, f, sort_keys=False, allow_unicode=True)
110
+
111
+ print("Done! All 'gradient_checkpoint_layers' keys removed.")
112
+
113
+ s2t_ar = ARSpeech2Text.from_pretrained(
114
+ model_tag=f"espnet/owsm_v4_medium_1B",
115
+ device=device,
116
+ beam_size=5,
117
+ ctc_weight=0.0,
118
+ maxlenratio=0.0,
119
+ # below are default values which can be overwritten in __call__
120
+ lang_sym="<eng>",
121
+ task_sym="<asr>",
122
+ predict_time=False,
123
+ )
124
+
125
+ # CTC looks okay.
126
  s2t_ctc = CTCSpeech2Text.from_pretrained(
127
  model_tag=f"espnet/owsm_ctc_v4_1B",
128
  device=device,