AbdoIR commited on
Commit
48d8acc
·
verified ·
1 Parent(s): a703706

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +11 -17
main.py CHANGED
@@ -10,17 +10,15 @@ from transformers import (
10
  AutoModelForSeq2SeqLM,
11
  pipeline
12
  )
13
- from huggingface_hub import snapshot_download
14
  from torch.quantization import quantize_dynamic
15
  import logging
16
  import ffmpeg
17
  import tempfile
18
 
19
- # Force HF cache to /tmp
20
  os.environ["HF_HOME"] = "/tmp/huggingface"
21
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
22
  os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
23
-
24
  os.makedirs(os.environ["HF_HOME"], exist_ok=True)
25
 
26
  # Silence all transformers and huggingface logging
@@ -31,22 +29,20 @@ logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
31
  app = Flask(__name__)
32
  CORS(app)
33
 
34
- # ========== Load Whisper Model (quantized + small) ==========
35
- def load_whisper_model(model_size="small", save_dir="/tmp/models_cache/whisper"):
36
- os.makedirs(save_dir, exist_ok=True)
37
  model_name = f"openai/whisper-{model_size}"
38
- processor = WhisperProcessor.from_pretrained(model_name, cache_dir=save_dir)
39
- model = WhisperForConditionalGeneration.from_pretrained(model_name, cache_dir=save_dir)
40
  model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
41
  model.to("cuda" if torch.cuda.is_available() else "cpu")
42
  return processor, model
43
 
44
  # ========== Load Grammar Correction Model (quantized) ==========
45
- def load_grammar_model(save_dir="/tmp/models_cache/grammar_corrector"):
46
- os.makedirs(save_dir, exist_ok=True)
47
  model_name = "prithivida/grammar_error_correcter_v1"
48
- tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=save_dir)
49
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=save_dir)
50
  model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
51
  grammar_pipeline = pipeline(
52
  "text2text-generation",
@@ -115,10 +111,8 @@ def correct_grammar(text, grammar_pipeline):
115
  return '. '.join([r['generated_text'] for r in results])
116
 
117
  # ========== Initialize Models ==========
118
- # processor, whisper_model = load_whisper_model("small")
119
- # grammar_pipeline = load_grammar_model()
120
- processor = WhisperProcessor.from_pretrained(model_name)
121
- model = WhisperForConditionalGeneration.from_pretrained(model_name)
122
 
123
  # ========== Warm-Up Models ==========
124
  def warm_up_models():
@@ -158,4 +152,4 @@ def transcribe():
158
 
159
  # ========== Run App ==========
160
  if __name__ == '__main__':
161
- app.run(host="0.0.0.0", debug=False, port=7860)
 
10
  AutoModelForSeq2SeqLM,
11
  pipeline
12
  )
 
13
  from torch.quantization import quantize_dynamic
14
  import logging
15
  import ffmpeg
16
  import tempfile
17
 
18
# ========== Force HF cache to /tmp ==========
# Redirect every Hugging Face cache location to /tmp, which is writable in
# containerized deployments (e.g. HF Spaces) where $HOME may be read-only.
_CACHE_DIRS = {
    "HF_HOME": "/tmp/huggingface",
    "TRANSFORMERS_CACHE": "/tmp/huggingface/transformers",
    "HF_HUB_CACHE": "/tmp/huggingface/hub",
}
os.environ.update(_CACHE_DIRS)
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
23
 
24
  # Silence all transformers and huggingface logging
 
29
  app = Flask(__name__)
30
  CORS(app)
31
 
32
# ========== Load Whisper Model (quantized) ==========
def load_whisper_model(model_size="small"):
    """Load an ``openai/whisper-*`` checkpoint and its processor.

    On CPU the model's linear layers are dynamically quantized to int8 to
    reduce memory use and speed up inference. PyTorch dynamic quantization
    is CPU-only, so when CUDA is available the model is kept in full
    precision and moved to the GPU instead (the original code quantized
    first and then called ``.to("cuda")``, which is unsupported for
    dynamically quantized models).

    Args:
        model_size: Whisper size suffix (e.g. ``"tiny"``, ``"small"``).

    Returns:
        Tuple ``(processor, model)`` ready for inference.
    """
    model_name = f"openai/whisper-{model_size}"
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)
    if torch.cuda.is_available():
        # quantize_dynamic produces a CPU-only model; keep fp32 on GPU.
        model = model.to("cuda")
    else:
        model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    return processor, model
40
 
41
  # ========== Load Grammar Correction Model (quantized) ==========
42
+ def load_grammar_model():
 
43
  model_name = "prithivida/grammar_error_correcter_v1"
44
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
45
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
46
  model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
47
  grammar_pipeline = pipeline(
48
  "text2text-generation",
 
111
  return '. '.join([r['generated_text'] for r in results])
112
 
113
  # ========== Initialize Models ==========
114
+ processor, whisper_model = load_whisper_model("small")
115
+ grammar_pipeline = load_grammar_model()
 
 
116
 
117
  # ========== Warm-Up Models ==========
118
  def warm_up_models():
 
152
 
153
# ========== Run App ==========
# Bind to all interfaces on port 7860 (the conventional HF Spaces port);
# debug mode stays off for production serving.
if __name__ == '__main__':
    app.run(debug=False, host="0.0.0.0", port=7860)