teoha committed on
Commit 2592a7b · 1 Parent(s): d973984

Replicate API logic

Files changed (8)
  1. .env +2 -0
  2. Dockerfile +18 -0
  3. README.md +1 -0
  4. holosubs.py +119 -0
  5. main.py +14 -0
  6. requirements.txt +16 -0
  7. transcribe.py +78 -0
  8. youtubeaudio.py +51 -0
.env ADDED
@@ -0,0 +1,2 @@
+ peft_model_id ="teoha/openai-whisper-medium-LORA-ja"
+ install_location = "/tmp/elite_understanding"
Dockerfile ADDED
@@ -0,0 +1,18 @@
+ FROM pytorch/pytorch
+
+ WORKDIR /code
+ RUN mkdir /.cache
+ RUN chmod 1777 /.cache
+ COPY ./requirements.txt /code/requirements.txt
+ RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib/libcudart.so' >> ~/.bashrc
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+ RUN /opt/conda/bin/pip install peft
+ RUN /opt/conda/bin/pip install -qq https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
+ # Expose the HUGGINGFACE_TOKEN secret at build time and use it to log in to the Hugging Face Hub
+ RUN --mount=type=secret,id=HUGGINGFACE_TOKEN,mode=0444,required=true \
+     huggingface-cli login --token $(cat /run/secrets/HUGGINGFACE_TOKEN) && \
+     echo "HUGGINGFACE_TOKEN=$( cat /run/secrets/HUGGINGFACE_TOKEN )" >> .env
+
+ COPY . .
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -6,6 +6,7 @@ colorTo: green
  sdk: docker
  pinned: false
  license: mit
+ app_file: main.py
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
holosubs.py ADDED
@@ -0,0 +1,119 @@
+ """
+ Entry point and main execution block of the video transcription job
+ """
+ import re
+
+ from dotenv import load_dotenv
+ from youtubeaudio import YoutubeAudio
+ from transcribe import Transcriber
+ import torchaudio
+ from pyannote.audio import Pipeline
+ from webvtt import WebVTT, Caption
+ import torch
+ import logging
+ from huggingface_hub._login import _login
+ import os
+
+ load_dotenv()
+ WHISPER_SAMPLE_RATE=16000
+ TIMESTAMP_PATTERN=r'[0-9]+:[0-9]+:[0-9]+\.[0-9]+'
+ MAX_CHUNK_DURATION=30000 # ms
+
+ format = "%(asctime)s: %(message)s"
+ logging.basicConfig(format=format, level=logging.DEBUG,
+                     datefmt="%H:%M:%S")
+ _login(token=os.getenv('HUGGINGFACE_TOKEN'), add_to_git_credential=False)
+
+ def get_video_vtt(url) -> str:
+     # Download wav file
+     ytaudio=YoutubeAudio(url)
+     ytaudio.download_audio()
+     # Load audio
+     audio, sample_rate = torchaudio.load(ytaudio.filename)
+     audio_dict={"waveform": audio, "sample_rate": sample_rate}
+     # Diarization
+     pipeline = Pipeline.from_pretrained('pyannote/speaker-diarization@2.1', use_auth_token=True)
+     dzs = pipeline(audio_dict)
+     groups = group_segments(str(dzs).splitlines())
+     # Preprocess audio segments for translation
+     audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=WHISPER_SAMPLE_RATE)
+     audio_segments, timestamps = get_segments(groups, audio)
+     # Decoding audio segments into subtitles
+     transcriber = Transcriber(task="translate")
+     captions = decode_segments(audio_segments, timestamps, transcriber)
+     vtt = create_vtt(captions)
+     ytaudio.clean()
+     return vtt.content
+
+ def decode_segments(audio_segments, timestamps, transcriber):
+     captions = []
+     for i, segment in enumerate(audio_segments):
+         result = transcriber.decode(segment)
+         captions.append(Caption(timestamps[i][0], timestamps[i][1], result))
+         logging.info(f"Chunk output no.{i+1}: {result}")
+     return captions
+
+ def millisec(timeStr):
+     spl = timeStr.split(":")
+     s = int((int(spl[0]) * 60 * 60 + int(spl[1]) * 60 + float(spl[2])) * 1000)
+     return s
+
+ def group_segments(dzs):
+     groups = []
+     g = []
+     lastend = 0
+
+     for d in dzs:
+         if g and (g[0].split()[-1] != d.split()[-1]): # speaker changed
+             groups.append(g)
+             g = []
+
+         g.append(d)
+
+         end = re.findall(TIMESTAMP_PATTERN, string=d)[1]
+         end = millisec(end)
+         if (lastend > end): # segment engulfed by a previous segment
+             groups.append(g)
+             g = []
+         else:
+             lastend = end
+     if g:
+         groups.append(g)
+     logging.debug(groups)
+     return groups
+
+ def create_vtt(captions):
+     vtt = WebVTT()
+     for caption in captions:
+         vtt.captions.append(caption)
+     return vtt
+     # vtt.save(path)
+
+ def get_segments(groups, audio):
+     monoaudio=torch.mean(input=audio,dim=0).numpy()
+     audio_segments = []
+     timestamps = []
+     for g in groups:
+         cur_start_time, cur_end_time = re.findall(TIMESTAMP_PATTERN, string=g[0])
+         cur_start_millisec = millisec(cur_start_time) #- spacermilli
+         cur_end_millisec = millisec(cur_end_time) #- spacermilli
+         for window in g[1:]:
+             start_time, end_time = re.findall(TIMESTAMP_PATTERN, string=window)
+             start_millisec = millisec(start_time) #- spacermilli
+             end_millisec = millisec(end_time) #- spacermilli
+             # Check if new window exceeds chunk size
+             seg_duration_with_window=end_millisec-cur_start_millisec
+             if seg_duration_with_window>MAX_CHUNK_DURATION: # Segment with window exceeds max chunk duration
+                 start_frame, end_frame = cur_start_millisec*WHISPER_SAMPLE_RATE//1000, cur_end_millisec*WHISPER_SAMPLE_RATE//1000
+                 audio_segments.append(monoaudio[start_frame:end_frame])
+                 timestamps.append((cur_start_time, cur_end_time))
+                 cur_start_time, cur_end_time = start_time, end_time
+                 cur_start_millisec, cur_end_millisec = start_millisec, end_millisec
+             else:
+                 cur_end_time=end_time
+                 cur_end_millisec=end_millisec
+         # Final update
+         start_frame, end_frame = cur_start_millisec*WHISPER_SAMPLE_RATE//1000, cur_end_millisec*WHISPER_SAMPLE_RATE//1000
+         audio_segments.append(monoaudio[start_frame:end_frame])
+         timestamps.append((cur_start_time, cur_end_time))
+     return audio_segments, timestamps
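For reference, a minimal sketch of the input that group_segments and millisec expect: str(dzs).splitlines() yields one line per diarized turn in pyannote's text format. The timestamps and speaker labels below are illustrative only, and importing holosubs assumes HUGGINGFACE_TOKEN is set in the environment, since the module logs in at import time.

from holosubs import group_segments, millisec

# Illustrative diarization lines (format: "[ start -->  end] label SPEAKER_xx"); values are made up.
lines = [
    "[ 00:00:00.497 -->  00:00:09.762] A SPEAKER_00",
    "[ 00:00:10.000 -->  00:00:12.500] B SPEAKER_01",
]
print(millisec("00:00:09.762"))  # 9762 (milliseconds)
print(group_segments(lines))     # two groups, split because the speaker label changes between lines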
main.py ADDED
@@ -0,0 +1,14 @@
+ from fastapi import FastAPI
+ from holosubs import get_video_vtt
+ from pydantic import BaseModel
+
+ class Url(BaseModel):
+     url: str
+
+
+ app = FastAPI()
+
+ @app.post("/captions/")
+ def read_root(url: Url):
+     vtt_captions = get_video_vtt(url.url)
+     return {"captions": vtt_captions}
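As a usage sketch, the endpoint can be exercised with the requests library already pinned in requirements.txt; the host and port assume the uvicorn command from the Dockerfile, and the video URL is a placeholder.

import requests

# Hypothetical client call; replace VIDEO_ID with a real YouTube video id.
resp = requests.post(
    "http://localhost:7860/captions/",
    json={"url": "https://www.youtube.com/watch?v=VIDEO_ID"},
)
resp.raise_for_status()
print(resp.json()["captions"])  # WebVTT text produced by get_video_vtt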
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ fastapi==0.74.*
+ requests==2.27.*
+ sentencepiece==0.1.*
+ uvicorn[standard]==0.17.*
+ numpy==1.24.4
+ pyannote.audio==1.1.2
+ pyannote.core==5.0.0
+ pyannote.database==5.0.1
+ pyannote.metrics==3.2.1
+ pyannote.pipeline==1.5.2
+ python-dotenv==1.0.0
+ torch==2.0.1
+ torchaudio==2.0.2
+ transformers==4.31.0
+ webvtt_py==0.4.6
+ yt_dlp==2023.7.6
transcribe.py ADDED
@@ -0,0 +1,78 @@
+ """
+ Represents a model that transcribes and translates audio.
+ """
+
+ import logging
+ import os
+ from typing import Union
+
+ import numpy as np
+ import torch
+ from dotenv import load_dotenv
+ from peft import PeftConfig, PeftModel
+ from transformers import (AutomaticSpeechRecognitionPipeline,
+                           WhisperForConditionalGeneration, WhisperProcessor,
+                           WhisperTokenizer)
+
+ load_dotenv()
+ format = "%(asctime)s: %(message)s"
+ logging.basicConfig(format=format, level=logging.DEBUG,
+                     datefmt="%H:%M:%S")
+
+ class Transcriber:
+     def __init__(self, model_id="teoha/openai-whisper-medium-LORA-ja", language="Japanese", task="translate"):
+         self.language=language
+         self.task=task
+         peft_model_id = model_id if model_id else os.getenv('peft_model_id')
+         # TODO: Fix downloading and installing the model locally
+         # self.install_model(peft_model_id)
+         self.initialize_pipe(peft_model_id) # initialize pipe
+
+     def install_model(self, peft_model_id:str) -> None:
+         save_location = os.path.join(os.getenv('install_location'), peft_model_id)
+         offload_location = os.path.join(os.getenv('install_location'), "offload")
+         # Save model
+         peft_config = PeftConfig.from_pretrained(peft_model_id)
+         model = WhisperForConditionalGeneration.from_pretrained(
+             peft_config.base_model_name_or_path,
+             load_in_8bit=False, device_map="auto"
+         )
+         model = PeftModel.from_pretrained(model, peft_model_id, offload_folder=offload_location)
+         model.save_pretrained(save_location)
+
+         # Save tokenizer/processor
+         tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
+         processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
+         tokenizer.save_pretrained(save_location)
+         processor.save_pretrained(save_location)
+         logging.info("Installation completed successfully")
+
+     def initialize_pipe(self, peft_model_id: str) -> None:
+         offload_location = os.path.join(os.getenv('install_location'), "offload")
+         # Initialize model configs
+         peft_config = PeftConfig.from_pretrained(peft_model_id)
+         model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, load_in_8bit=False, device_map="auto")
+         model = PeftModel.from_pretrained(model, peft_model_id, offload_folder=offload_location)
+         tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
+         processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
+         feature_extractor = processor.feature_extractor
+         # Initialize class variables
+         self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=self.language, task=self.task)
+         self.pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
+         logging.info("Pipe successfully initialized")
+
+
+     def decode(self, audio: Union[np.ndarray, bytes, str]) -> str:
+         '''
+         Transcribes a sequence of floats representing an audio snippet.
+         Args:
+             inputs (:obj:`np.ndarray` or :obj:`bytes` or :obj:`str`):
+                 The input is either a raw waveform (:obj:`np.ndarray` of shape (n,) of type :obj:`np.float32` or
+                 :obj:`np.float64`) at the correct sampling rate (no further check will be done) or a :obj:`str` that is
+                 the filename of an audio file; the file will be read at the correct sampling rate to get the waveform
+                 using `ffmpeg`. This requires `ffmpeg` to be installed on the system. If `inputs` is :obj:`bytes` it is
+                 supposed to be the content of an audio file and is interpreted by `ffmpeg` in the same way.
+         '''
+         with torch.cuda.amp.autocast():
+             text = self.pipe(audio, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids})["text"]
+         return text
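A minimal sketch of using Transcriber on its own, assuming the LoRA adapter can be fetched from the Hub (valid credentials) and that install_location is set in the environment as in .env; the waveform here is synthetic silence, used only as a stand-in for real 16 kHz audio.

import numpy as np
from transcribe import Transcriber

transcriber = Transcriber(task="translate")       # loads Whisper medium plus the LoRA adapter
waveform = np.zeros(16000 * 5, dtype=np.float32)  # 5 seconds of silence at 16 kHz
print(transcriber.decode(waveform))               # English translation of the audio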
youtubeaudio.py ADDED
@@ -0,0 +1,51 @@
+ """
+ Represents a YouTube video
+ """
+
+ from dotenv import load_dotenv
+ import logging
+ from yt_dlp import YoutubeDL
+ import os
+ from pathlib import Path
+
+ load_dotenv()
+ format = "%(asctime)s: %(message)s"
+ logging.basicConfig(format=format, level=logging.DEBUG,
+                     datefmt="%H:%M:%S")
+
+ class YoutubeAudio:
+     def __init__(self, url, dir="/tmp/holosubs/audio"):
+         self.url=url
+         self.dir=dir
+         self.filename=None # set by progress_hook once the download finishes
+
+     def download_audio(self):
+         ydl_opts = {
+             'outtmpl': os.path.join(self.dir, "%(id)s_%(epoch)s.%(ext)s"),
+             'logger': logging,
+             'progress_hooks': [self.progress_hook],
+             'format': 'm4a/bestaudio/best',
+             'postprocessors': [{ # Extract audio using ffmpeg
+                 'key': 'FFmpegExtractAudio',
+                 'preferredcodec': 'wav',
+             }]
+         }
+         with YoutubeDL(ydl_opts) as ydl:
+             error_code = ydl.download([self.url])
+
+     def clean(self):
+         if not self.filename:
+             logging.error("Audio not downloaded")
+             return
+         if os.path.exists(self.filename):
+             os.remove(self.filename)
+             logging.info(f"File {self.filename} successfully removed")
+             self.filename=None
+         else:
+             print(f"File {self.filename} does not exist")
+
+     def progress_hook(self, d):
+         if d['status'] == 'finished':
+             self.filename=os.path.join(self.dir, Path(d.get('info_dict').get('_filename')).stem + ".wav")
+             print(f'Done downloading {self.filename}, now post-processing ...')
+
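And a sketch of the download/clean lifecycle this class is built around; the URL is a placeholder and the exact output path depends on yt_dlp's output template.

from youtubeaudio import YoutubeAudio

ytaudio = YoutubeAudio("https://www.youtube.com/watch?v=VIDEO_ID")  # placeholder URL
ytaudio.download_audio()  # progress_hook records the extracted .wav path in ytaudio.filename
print(ytaudio.filename)   # e.g. /tmp/holosubs/audio/<id>_<epoch>.wav
ytaudio.clean()           # removes the temporary file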