Replicate api logic
Files changed:
- .env +2 -0
- Dockerfile +18 -0
- README.md +1 -0
- holosubs.py +119 -0
- main.py +14 -0
- requirements.txt +16 -0
- transcribe.py +78 -0
- youtubeaudio.py +51 -0
.env
ADDED
@@ -0,0 +1,2 @@
+peft_model_id="teoha/openai-whisper-medium-LORA-ja"
+install_location="/tmp/elite_understanding"
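
These values are consumed at runtime through python-dotenv, as transcribe.py does below; a minimal sketch:

    from dotenv import load_dotenv
    import os

    load_dotenv()                                # loads .env into the process environment
    model_id = os.getenv('peft_model_id')        # "teoha/openai-whisper-medium-LORA-ja"
    install_dir = os.getenv('install_location')  # "/tmp/elite_understanding"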
Dockerfile
ADDED
@@ -0,0 +1,18 @@
+FROM pytorch/pytorch
+
+WORKDIR /code
+RUN mkdir /.cache
+RUN chmod 1777 /.cache
+COPY ./requirements.txt /code/requirements.txt
+RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib' >> ~/.bashrc
+RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+RUN /opt/conda/bin/pip install peft
+RUN /opt/conda/bin/pip install -qq https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
+# Expose the HUGGINGFACE_TOKEN secret at build time and use it to log in to the Hugging Face Hub
+RUN --mount=type=secret,id=HUGGINGFACE_TOKEN,mode=0444,required=true \
+    huggingface-cli login --token $(cat /run/secrets/HUGGINGFACE_TOKEN) && \
+    echo "HUGGINGFACE_TOKEN=$(cat /run/secrets/HUGGINGFACE_TOKEN)" >> .env
+
+COPY . .
+
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
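
On Hugging Face Spaces the HUGGINGFACE_TOKEN secret is injected automatically at build time. For a local build, a rough equivalent (the token file path is a placeholder) would be:

    DOCKER_BUILDKIT=1 docker build --secret id=HUGGINGFACE_TOKEN,src=./hf_token.txt -t holosubs .
    docker run -p 7860:7860 holosubs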
README.md
CHANGED
@@ -6,6 +6,7 @@ colorTo: green
 sdk: docker
 pinned: false
 license: mit
+app_file: main.py
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
holosubs.py
ADDED
@@ -0,0 +1,119 @@
+"""
+Entry point and main execution block of the video transcription job
+"""
+import re
+
+from dotenv import load_dotenv
+from youtubeaudio import YoutubeAudio
+from transcribe import Transcriber
+import torchaudio
+from pyannote.audio import Pipeline
+from webvtt import WebVTT, Caption
+import torch
+import logging
+from huggingface_hub._login import _login
+import os
+
+load_dotenv()
+WHISPER_SAMPLE_RATE = 16000
+TIMESTAMP_PATTERN = r'[0-9]+:[0-9]+:[0-9]+\.[0-9]+'
+MAX_CHUNK_DURATION = 30000  # ms
+
+format = "%(asctime)s: %(message)s"
+logging.basicConfig(format=format, level=logging.DEBUG,
+                    datefmt="%H:%M:%S")
+_login(token=os.getenv('HUGGINGFACE_TOKEN'), add_to_git_credential=False)
+
+def get_video_vtt(url) -> str:
+    # Download wav file
+    ytaudio = YoutubeAudio(url)
+    ytaudio.download_audio()
+    # Load audio
+    audio, sample_rate = torchaudio.load(ytaudio.filename)
+    audio_dict = {"waveform": audio, "sample_rate": sample_rate}
+    # Diarization
+    pipeline = Pipeline.from_pretrained('pyannote/speaker-diarization@2.1', use_auth_token=True)
+    dzs = pipeline(audio_dict)
+    groups = group_segments(str(dzs).splitlines())
+    # Preprocess audio segments for translation
+    audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=WHISPER_SAMPLE_RATE)
+    audio_segments, timestamps = get_segments(groups, audio)
+    # Decode audio segments into subtitles
+    transcriber = Transcriber(task="translate")
+    captions = decode_segments(audio_segments, timestamps, transcriber)
+    vtt = create_vtt(captions)
+    ytaudio.clean()
+    return vtt.content
+
+def decode_segments(audio_segments, timestamps, transcriber):
+    captions = []
+    for i, segment in enumerate(audio_segments):
+        result = transcriber.decode(segment)
+        captions.append(Caption(timestamps[i][0], timestamps[i][1], result))
+        logging.info(f"Chunk output no.{i+1}: {result}")
+    return captions
+
+def millisec(timeStr):
+    spl = timeStr.split(":")
+    s = int((int(spl[0]) * 60 * 60 + int(spl[1]) * 60 + float(spl[2])) * 1000)
+    return s
+
+def group_segments(dzs):
+    groups = []
+    g = []
+    lastend = 0
+
+    for d in dzs:
+        if g and (g[0].split()[-1] != d.split()[-1]):  # speaker changed
+            groups.append(g)
+            g = []
+
+        g.append(d)
+
+        end = re.findall(TIMESTAMP_PATTERN, string=d)[1]
+        end = millisec(end)
+        if lastend > end:  # segment engulfed by a previous segment
+            groups.append(g)
+            g = []
+        else:
+            lastend = end
+    if g:
+        groups.append(g)
+    logging.debug(groups)
+    return groups
+
+def create_vtt(captions):
+    vtt = WebVTT()
+    for caption in captions:
+        vtt.captions.append(caption)
+    return vtt
+    # vtt.save(path)
+
+def get_segments(groups, audio):
+    monoaudio = torch.mean(input=audio, dim=0).numpy()
+    audio_segments = []
+    timestamps = []
+    for g in groups:
+        cur_start_time, cur_end_time = re.findall(TIMESTAMP_PATTERN, string=g[0])
+        cur_start_millisec = millisec(cur_start_time)
+        cur_end_millisec = millisec(cur_end_time)
+        for window in g[1:]:
+            start_time, end_time = re.findall(TIMESTAMP_PATTERN, string=window)
+            start_millisec = millisec(start_time)
+            end_millisec = millisec(end_time)
+            # Check if adding this window would exceed the max chunk duration
+            seg_duration_with_window = end_millisec - cur_start_millisec
+            if seg_duration_with_window > MAX_CHUNK_DURATION:
+                # Flush the current chunk and start a new one at this window
+                start_frame, end_frame = cur_start_millisec*WHISPER_SAMPLE_RATE//1000, cur_end_millisec*WHISPER_SAMPLE_RATE//1000
+                audio_segments.append(monoaudio[start_frame:end_frame])
+                timestamps.append((cur_start_time, cur_end_time))
+                cur_start_time, cur_end_time = start_time, end_time
+                cur_start_millisec, cur_end_millisec = start_millisec, end_millisec
+            else:
+                # Extend the current chunk to cover this window
+                cur_end_time = end_time
+                cur_end_millisec = end_millisec
+        # Flush the final chunk of the group
+        start_frame, end_frame = cur_start_millisec*WHISPER_SAMPLE_RATE//1000, cur_end_millisec*WHISPER_SAMPLE_RATE//1000
+        audio_segments.append(monoaudio[start_frame:end_frame])
+        timestamps.append((cur_start_time, cur_end_time))
+    return audio_segments, timestamps
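
group_segments works on the line-per-turn textual rendering of the pyannote diarization output. A minimal sketch of the intended behaviour, with illustrative (not real) timestamps and speaker labels:

    lines = [
        "[ 00:00:00.497 -->  00:00:09.762] A SPEAKER_01",
        "[ 00:00:10.012 -->  00:00:15.330] B SPEAKER_01",
        "[ 00:00:15.800 -->  00:00:21.101] C SPEAKER_00",
    ]
    groups = group_segments(lines)
    # -> two groups: the first two turns share SPEAKER_01 and are merged,
    #    the third starts a new group when the speaker changes.
    millisec("00:00:15.330")  # -> 15330, the conversion get_segments chunks by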
main.py
ADDED
@@ -0,0 +1,14 @@
+from fastapi import FastAPI
+from holosubs import get_video_vtt
+from pydantic import BaseModel
+
+class Url(BaseModel):
+    url: str
+
+
+app = FastAPI()
+
+@app.post("/captions/")
+def get_captions(url: Url):
+    vtt_captions = get_video_vtt(url.url)
+    return {"captions": vtt_captions}
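
With the container running on port 7860 (the port in the Dockerfile CMD), the endpoint can be exercised like this; the video URL is a placeholder:

    import requests

    resp = requests.post(
        "http://localhost:7860/captions/",
        json={"url": "https://www.youtube.com/watch?v=VIDEO_ID"},
    )
    print(resp.json()["captions"])  # WebVTT subtitle text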
requirements.txt
ADDED
@@ -0,0 +1,16 @@
+fastapi==0.74.*
+requests==2.27.*
+sentencepiece==0.1.*
+uvicorn[standard]==0.17.*
+numpy==1.24.4
+pyannote.audio==1.1.2
+pyannote.core==5.0.0
+pyannote.database==5.0.1
+pyannote.metrics==3.2.1
+pyannote.pipeline==1.5.2
+python-dotenv==1.0.0
+torch==2.0.1
+torchaudio==2.0.2
+transformers==4.31.0
+webvtt_py==0.4.6
+yt_dlp==2023.7.6
transcribe.py
ADDED
@@ -0,0 +1,78 @@
+"""
+Represents a model that transcribes and translates audio.
+"""
+
+import logging
+import os
+from typing import Union
+
+import numpy as np
+import torch
+from dotenv import load_dotenv
+from peft import PeftConfig, PeftModel
+from transformers import (AutomaticSpeechRecognitionPipeline,
+                          WhisperForConditionalGeneration, WhisperProcessor,
+                          WhisperTokenizer)
+
+load_dotenv()
+format = "%(asctime)s: %(message)s"
+logging.basicConfig(format=format, level=logging.DEBUG,
+                    datefmt="%H:%M:%S")
+
+class Transcriber:
+    def __init__(self, model_id="teoha/openai-whisper-medium-LORA-ja", language="Japanese", task="translate"):
+        self.language = language
+        self.task = task
+        peft_model_id = model_id if model_id else os.getenv('peft_model_id')
+        # TODO: fix downloading and installing the model locally
+        # self.install_model(peft_model_id)
+        self.initialize_pipe(peft_model_id)  # initialize pipe
+
+    def install_model(self, peft_model_id: str) -> None:
+        save_location = os.path.join(os.getenv('install_location'), peft_model_id)
+        offload_location = os.path.join(os.getenv('install_location'), "offload")
+        # Save model
+        peft_config = PeftConfig.from_pretrained(peft_model_id)
+        model = WhisperForConditionalGeneration.from_pretrained(
+            peft_config.base_model_name_or_path,
+            load_in_8bit=False, device_map="auto"
+        )
+        model = PeftModel.from_pretrained(model, peft_model_id, offload_folder=offload_location)
+        model.save_pretrained(save_location)
+
+        # Save tokenizer/processor
+        tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
+        processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
+        tokenizer.save_pretrained(save_location)
+        processor.save_pretrained(save_location)
+        logging.info("Installation completed successfully")
+
+    def initialize_pipe(self, peft_model_id: str) -> None:
+        offload_location = os.path.join(os.getenv('install_location'), "offload")
+        # Initialize model configs
+        peft_config = PeftConfig.from_pretrained(peft_model_id)
+        model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, load_in_8bit=False, device_map="auto")
+        model = PeftModel.from_pretrained(model, peft_model_id, offload_folder=offload_location)
+        tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
+        processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
+        feature_extractor = processor.feature_extractor
+        # Initialize class variables
+        self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=self.language, task=self.task)
+        self.pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
+        logging.info("Pipe successfully initialized")
+
+
+    def decode(self, audio: Union[np.ndarray, bytes, str]) -> str:
+        '''
+        Transcribes a sequence of floats representing an audio snippet.
+
+        Args:
+            audio (:obj:`np.ndarray` or :obj:`bytes` or :obj:`str`):
+                Either a raw waveform (:obj:`np.ndarray` of shape (n,) of type :obj:`np.float32` or
+                :obj:`np.float64`) at the correct sampling rate (no further check will be done), or a
+                :obj:`str` filename of an audio file, which will be read at the correct sampling rate
+                using `ffmpeg` (this requires `ffmpeg` to be installed on the system). If `audio` is
+                :obj:`bytes`, it is taken to be the content of an audio file and is interpreted by
+                `ffmpeg` in the same way.
+        '''
+        with torch.cuda.amp.autocast():
+            text = self.pipe(audio, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids})["text"]
+        return text
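
A minimal usage sketch for Transcriber (the wav path is a placeholder; the first call downloads the base Whisper weights and the LoRA adapter from the Hub, and torch.cuda.amp.autocast assumes a CUDA device):

    from transcribe import Transcriber

    transcriber = Transcriber(task="translate")   # Japanese audio -> English text
    text = transcriber.decode("/tmp/sample.wav")  # or a float32 waveform at 16 kHz
    print(text)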
youtubeaudio.py
ADDED
@@ -0,0 +1,51 @@
+"""
+Represents a Youtube video
+"""
+
+from dotenv import load_dotenv
+import logging
+from yt_dlp import YoutubeDL
+import os
+from pathlib import Path
+
+load_dotenv()
+format = "%(asctime)s: %(message)s"
+logging.basicConfig(format=format, level=logging.DEBUG,
+                    datefmt="%H:%M:%S")
+
+class YoutubeAudio:
+    def __init__(self, url, dir="/tmp/holosubs/audio"):
+        self.url = url
+        self.dir = dir
+        self.filename = None  # set by progress_hook once the download finishes
+
+    def download_audio(self):
+        ydl_opts = {
+            'outtmpl': os.path.join(self.dir, "%(id)s_%(epoch)s.%(ext)s"),
+            'logger': logging,
+            'progress_hooks': [self.progress_hook],
+            'format': 'm4a/bestaudio/best',
+            'postprocessors': [{  # Extract audio using ffmpeg
+                'key': 'FFmpegExtractAudio',
+                'preferredcodec': 'wav',
+            }]
+        }
+        with YoutubeDL(ydl_opts) as ydl:
+            error_code = ydl.download([self.url])
+
+    def clean(self):
+        if not self.filename:
+            logging.error("Audio not downloaded")
+            return
+        if os.path.exists(self.filename):
+            os.remove(self.filename)
+            logging.info(f"File {self.filename} successfully removed")
+            self.filename = None
+        else:
+            logging.warning(f"File {self.filename} does not exist")
+
+    def progress_hook(self, d):
+        if d['status'] == 'finished':
+            # The postprocessor converts the download to wav, so record the .wav path
+            self.filename = os.path.join(self.dir, Path(d.get('info_dict').get('_filename')).stem + ".wav")
+            logging.info(f'Done downloading {self.filename}, now post-processing ...')
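
A minimal usage sketch for YoutubeAudio (the URL is a placeholder; ffmpeg must be on PATH for the wav conversion):

    from youtubeaudio import YoutubeAudio

    ytaudio = YoutubeAudio("https://www.youtube.com/watch?v=VIDEO_ID")
    ytaudio.download_audio()   # downloads the audio track and converts it to wav
    print(ytaudio.filename)    # full path of the downloaded wav file
    ytaudio.clean()            # removes the wav file afterwards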