Update main.py
main.py
CHANGED
@@ -10,17 +10,15 @@ from transformers import (
     AutoModelForSeq2SeqLM,
     pipeline
 )
-from huggingface_hub import snapshot_download
 from torch.quantization import quantize_dynamic
 import logging
 import ffmpeg
 import tempfile
 
-# Force HF cache to /tmp
+# ========== Force HF cache to /tmp ==========
 os.environ["HF_HOME"] = "/tmp/huggingface"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
 os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
-
 os.makedirs(os.environ["HF_HOME"], exist_ok=True)
 
 # Silence all transformers and huggingface logging
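A side note on this hunk: huggingface_hub and transformers resolve their cache directories from these environment variables when the modules are first imported, so the new block only takes effect reliably if it runs before the transformers imports above it, and recent transformers releases deprecate TRANSFORMERS_CACHE in favor of HF_HOME / HF_HUB_CACHE. A minimal sketch of the safer ordering; only the /tmp paths come from this commit, the rest is illustrative:

import os

# Point every Hugging Face cache at /tmp (writable in most containers/Spaces).
# Set these BEFORE importing transformers or huggingface_hub, because both
# libraries read the variables when their modules are first imported.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"  # legacy name; deprecated in newer transformers
os.makedirs(os.environ["HF_HOME"], exist_ok=True)

from transformers import WhisperProcessor, WhisperForConditionalGeneration  # noqa: E402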
@@ -31,22 +29,20 @@ logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
 app = Flask(__name__)
 CORS(app)
 
-# ========== Load Whisper Model (quantized…
-def load_whisper_model(model_size="small"…
-    os.makedirs(save_dir, exist_ok=True)
+# ========== Load Whisper Model (quantized) ==========
+def load_whisper_model(model_size="small"):
     model_name = f"openai/whisper-{model_size}"
-    processor = WhisperProcessor.from_pretrained(model_name…
-    model = WhisperForConditionalGeneration.from_pretrained(model_name…
+    processor = WhisperProcessor.from_pretrained(model_name)
+    model = WhisperForConditionalGeneration.from_pretrained(model_name)
     model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
     model.to("cuda" if torch.cuda.is_available() else "cpu")
     return processor, model
 
 # ========== Load Grammar Correction Model (quantized) ==========
-def load_grammar_model(…
-    os.makedirs(save_dir, exist_ok=True)
+def load_grammar_model():
     model_name = "prithivida/grammar_error_correcter_v1"
-    tokenizer = AutoTokenizer.from_pretrained(model_name…
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_name…
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
     model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
     grammar_pipeline = pipeline(
         "text2text-generation",
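One caveat in both loaders: quantize_dynamic (torch.quantization, re-homed as torch.ao.quantization in newer PyTorch) replaces nn.Linear with int8 kernels that only run on CPU backends, so the quantized model and the following model.to("cuda" ...) pull in opposite directions. A hedged sketch of one way to reconcile them; the model name matches the commit, but the branching is my assumption, not what main.py does:

import torch
from torch.quantization import quantize_dynamic
from transformers import WhisperForConditionalGeneration

def load_whisper(model_size="small"):
    model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{model_size}")
    if torch.cuda.is_available():
        # GPU path: keep float weights; dynamically quantized Linear has no CUDA kernels.
        return model.to("cuda")
    # CPU path: int8 dynamic quantization shrinks the weights and speeds up matmuls.
    return quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)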
@@ -115,10 +111,8 @@ def correct_grammar(text, grammar_pipeline):
     return '. '.join([r['generated_text'] for r in results])
 
 # ========== Initialize Models ==========
-
-
-processor = WhisperProcessor.from_pretrained(model_name)
-model = WhisperForConditionalGeneration.from_pretrained(model_name)
+processor, whisper_model = load_whisper_model("small")
+grammar_pipeline = load_grammar_model()
 
 # ========== Warm-Up Models ==========
 def warm_up_models():
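For context, the processor, whisper_model pair that replaces the old module-level from_pretrained calls is typically consumed like this; transcribe_array and the 16 kHz mono float32 input are my assumptions, not code from main.py:

import torch

def transcribe_array(audio, processor, model, sampling_rate=16000):
    # The processor converts raw audio into log-mel input features; generate()
    # produces token ids, and batch_decode maps them back to text.
    inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
    with torch.no_grad():
        predicted_ids = model.generate(inputs.input_features)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]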
@@ -158,4 +152,4 @@ def transcribe():
 
 # ========== Run App ==========
 if __name__ == '__main__':
-    app.run(host="0.0.0.0", debug=False, port=7860)
+    app.run(host="0.0.0.0", debug=False, port=7860)
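debug=False matters here beyond quieter logs: Flask's debug mode starts the auto-reloader, which imports the module a second time and would load and quantize every model twice; 7860 is the default port Hugging Face Spaces routes to. A small sketch with an explicit warm-up before serving; warm_up_models is defined in this file, but calling it here is my assumption:

if __name__ == "__main__":
    # Warm the models once so the first request doesn't pay the cold-start cost.
    warm_up_models()
    # debug=False keeps the auto-reloader off, so models load exactly once.
    app.run(host="0.0.0.0", port=7860, debug=False)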