Commit
·
4befcb0
1
Parent(s):
22c1f87
push up batching for transcription (A10 GPU)
Browse files
- handler.py +1 -1
- test endpoint.ipynb +3 -2
handler.py
CHANGED
|
@@ -31,7 +31,7 @@ SAMPLE_RATE = 16000
|
|
| 31 |
def whisper_config():
|
| 32 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
whisper_model = "large-v3"
|
| 34 |
-
batch_size =
|
| 35 |
compute_type = "float16" if device == "cuda" else "int8"
|
| 36 |
return device, batch_size, compute_type, whisper_model
|
| 37 |
|
|
|
|
| 31 |
def whisper_config():
|
| 32 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
whisper_model = "large-v3"
|
| 34 |
+
batch_size = 48 if device == "cuda" else 1
|
| 35 |
compute_type = "float16" if device == "cuda" else "int8"
|
| 36 |
return device, batch_size, compute_type, whisper_model
|
| 37 |
|
test endpoint.ipynb
CHANGED
|
@@ -2,13 +2,13 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"metadata": {},
|
| 7 |
"outputs": [],
|
| 8 |
"source": [
|
| 9 |
"from pathlib import Path\n",
|
| 10 |
"from retry import retry\n",
|
| 11 |
-
"import base64, requests, ffmpeg # ffmpeg-python\n",
|
| 12 |
"\n",
|
| 13 |
"token = \"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\"\n",
|
| 14 |
"API_URL = \"https://t4vtvikeag4f1yzd.eu-west-1.aws.endpoints.huggingface.cloud\"\n",
|
|
@@ -100,6 +100,7 @@
|
|
| 100 |
"else:\n",
|
| 101 |
" print(\"Transcription first...\")\n",
|
| 102 |
" response = query_transcription(audio_data=audio_data)\n",
|
|
|
|
| 103 |
" print(\"...then Diarization.\")\n",
|
| 104 |
" response = query_diarization(audio_data=audio_data, transcription=response)"
|
| 105 |
]
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
"metadata": {},
|
| 7 |
"outputs": [],
|
| 8 |
"source": [
|
| 9 |
"from pathlib import Path\n",
|
| 10 |
"from retry import retry\n",
|
| 11 |
+
"import time, base64, requests, ffmpeg # ffmpeg-python\n",
|
| 12 |
"\n",
|
| 13 |
"token = \"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\"\n",
|
| 14 |
"API_URL = \"https://t4vtvikeag4f1yzd.eu-west-1.aws.endpoints.huggingface.cloud\"\n",
|
|
|
|
| 100 |
"else:\n",
|
| 101 |
" print(\"Transcription first...\")\n",
|
| 102 |
" response = query_transcription(audio_data=audio_data)\n",
|
| 103 |
+
" time.sleep(30)\n",
|
| 104 |
" print(\"...then Diarization.\")\n",
|
| 105 |
" response = query_diarization(audio_data=audio_data, transcription=response)"
|
| 106 |
]
|