teoha committed on
Commit
56aa734
·
1 Parent(s): 08a889b

use Whisper Cpp

Browse files
.env DELETED
@@ -1,2 +0,0 @@
1
- peft_model_id ="teoha/openai-whisper-medium-LORA-ja"
2
- install_location = "/tmp/elite_understanding"
 
 
 
Dockerfile CHANGED
@@ -1,15 +1,33 @@
1
- FROM pytorch/pytorch
2
-
3
- WORKDIR /code
4
- RUN mkdir /.cache
5
- RUN chmod 1777 /.cache
6
- COPY ./requirements.txt /code/requirements.txt
7
- RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib/libcudart.so' >> ~/.bashrc
8
- RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
9
- RUN /opt/conda/bin/pip install peft
10
- RUN /opt/conda/bin/pip install -qq https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
11
- RUN --mount=type=secret,id=HUGGINGFACE_TOKEN,mode=0444,required=true \
12
- huggingface-cli login --token $(cat /run/secrets/HUGGINGFACE_TOKEN) && \
13
- echo "HUGGINGFACE_TOKEN=$( cat /run/secrets/HUGGINGFACE_TOKEN )" >> .env
14
- COPY . .
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
1
# --- Build stage: compile whisper.cpp and download the ggml model ---
FROM python:3.10 AS build
RUN apt-get update
# Fetch the latest whisper.cpp source tarball and unpack it into /whisper.
RUN mkdir /whisper && \
    wget -q https://github.com/ggerganov/whisper.cpp/tarball/master -O - | \
    tar -xz -C /whisper --strip-components 1

WORKDIR /whisper

# Model name (e.g. "base", "medium") supplied at build time: --build-arg model=...
ARG model
RUN bash ./models/download-ggml-model.sh "${model}"
RUN make main

# --- Runtime stage: Python app + the compiled whisper binary and model ---
FROM python:3.10 AS whisper

RUN apt-get update \
    && apt-get install -y libsdl2-dev alsa-utils ffmpeg \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt /
RUN pip install --no-cache-dir -r /requirements.txt
COPY main.py /root
COPY youtubeaudio.py /root

WORKDIR /root
# ARG must be re-declared per stage; expose it to the app as $model.
ARG model
ENV model=$model
RUN mkdir /root/models
# World-writable results dir so the server can write output when run as any user.
RUN mkdir -p -m 777 /tmp/holosubs/results
COPY --from=build "/whisper/models/ggml-${model}.bin" "/root/models/ggml-${model}.bin"
COPY --from=build /whisper/main /usr/local/bin/whisper

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
__pycache__/youtubeaudio.cpython-310.pyc ADDED
Binary file (2.42 kB). View file
 
docker-compose.yml DELETED
@@ -1,8 +0,0 @@
1
- services:
2
- holosubs:
3
- image: holosubs
4
- container_name: local_holosubs
5
- ports:
6
- - "7860:7860"
7
- volumes:
8
- - ~/.cache:/.cache
 
 
 
 
 
 
 
 
 
holosubs.py DELETED
@@ -1,119 +0,0 @@
1
- """"
2
- Entry point and main execution block of the video transcription job
3
- """
4
- import re
5
-
6
- from dotenv import load_dotenv
7
- from youtubeaudio import YoutubeAudio
8
- from transcribe import Transcriber
9
- import torchaudio
10
- from pyannote.audio import Pipeline
11
- from webvtt import WebVTT, Caption
12
- import torch
13
- import logging
14
- from huggingface_hub._login import _login
15
- import os
16
-
17
- load_dotenv()
18
- WHISPER_SAMPLE_RATE=16000
19
- TIMESTAMP_PATTERN='[0-9]+:[0-9]+:[0-9]+\.[0-9]+'
20
- MAX_CHUNK_DURATION=30000 # ms
21
-
22
- format = "%(asctime)s: %(message)s"
23
- logging.basicConfig(format=format, level=logging.DEBUG,
24
- datefmt="%H:%M:%S")
25
- _login(token=os.getenv('HUGGINGFACE_TOKEN'), add_to_git_credential=False)
26
-
27
def get_video_vtt(url) -> str:
    """End-to-end pipeline: download audio, diarize, chunk, translate, emit WebVTT."""
    # Fetch the video's audio track as a wav file
    yt_audio = YoutubeAudio(url)
    yt_audio.download_audio()
    # Load the waveform from disk
    waveform, sample_rate = torchaudio.load(yt_audio.filename)
    diarization_input = {"waveform": waveform, "sample_rate": sample_rate}
    # Speaker diarization via pyannote
    pipeline = Pipeline.from_pretrained('pyannote/speaker-diarization@2.1', use_auth_token=True)
    diarization = pipeline(diarization_input)
    speaker_groups = group_segments(str(diarization).splitlines())
    # Resample to Whisper's expected rate before slicing into segments
    waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=WHISPER_SAMPLE_RATE)
    segments, stamps = get_segments(speaker_groups, waveform)
    # Translate each segment and assemble the VTT document
    transcriber = Transcriber(task="translate")
    captions = decode_segments(segments, stamps, transcriber)
    vtt = create_vtt(captions)
    yt_audio.clean()
    return vtt.content
47
-
48
def decode_segments(audio_segments, timestamps, transcriber):
    """Decode each audio segment and pair it with its timestamps as a Caption."""
    captions = []
    for chunk_no, (segment, (start, end)) in enumerate(zip(audio_segments, timestamps), start=1):
        text = transcriber.decode(segment)
        captions.append(Caption(start, end, text))
        logging.info(f"Chunk output no.{chunk_no}: {text}")
    return captions
55
-
56
def millisec(timeStr):
    """Convert an 'H:MM:SS.fff' timestamp string to integer milliseconds."""
    parts = timeStr.split(":")
    total_seconds = int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
    return int(total_seconds * 1000)
60
-
61
def group_segments(dzs):
    """Group diarization lines into consecutive runs.

    A new group starts when the speaker label (the line's last
    whitespace-separated token) changes, or when a segment's end time is
    earlier than the running maximum (segment engulfed by a previous one).
    """
    groups = []
    current = []
    last_end = 0

    for line in dzs:
        # Speaker changed → close out the current run before appending.
        if current and current[0].split()[-1] != line.split()[-1]:
            groups.append(current)
            current = []

        current.append(line)

        # Second timestamp on the line is the segment's end time.
        end_ms = millisec(re.findall('[0-9]+:[0-9]+:[0-9]+\.[0-9]+', string=line)[1])
        if last_end > end_ms:
            # Engulfed segment: flush immediately, keep last_end unchanged.
            groups.append(current)
            current = []
        else:
            last_end = end_ms

    if current:
        groups.append(current)
    logging.debug(groups)
    return groups
84
-
85
def create_vtt(captions):
    """Assemble a WebVTT document from an iterable of Caption objects."""
    document = WebVTT()
    for cap in captions:
        document.captions.append(cap)
    return document
91
-
92
def get_segments(groups, audio):
    """Slice the audio into per-group chunks of at most MAX_CHUNK_DURATION ms.

    Args:
        groups: output of group_segments() — lists of diarization lines, each
            containing two 'H:MM:SS.fff' timestamps matching TIMESTAMP_PATTERN.
        audio: torch tensor; assumed (channels, frames) at WHISPER_SAMPLE_RATE
            (caller resamples before calling) — TODO confirm.

    Returns:
        (audio_segments, timestamps): numpy slices of the mono waveform and
        their (start, end) timestamp strings, index-aligned.
    """
    # Collapse channels to a mono numpy waveform.
    monoaudio=torch.mean(input=audio,dim=0).numpy()
    audio_segments = []
    timestamps = []
    for g in groups:
        # The current chunk starts as the group's first diarization window.
        cur_start_time, cur_end_time = re.findall(TIMESTAMP_PATTERN, string=g[0])
        cur_start_millisec = millisec(cur_start_time) #- spacermilli
        cur_end_millisec = millisec(cur_end_time) #- spacermilli
        for window in g[1:]:
            start_time, end_time = re.findall(TIMESTAMP_PATTERN, string=window)
            start_millisec = millisec(start_time) #- spacermilli
            end_millisec = millisec(end_time) #- spacermilli
            # Check if new window exceeds chunk size
            seg_duration_with_window=end_millisec-cur_start_millisec
            if seg_duration_with_window>MAX_CHUNK_DURATION: # Segment with window exceeds max chunk duration
                # Flush the current chunk and restart the chunk at this window.
                start_frame, end_frame = cur_start_millisec*WHISPER_SAMPLE_RATE//1000, cur_end_millisec*WHISPER_SAMPLE_RATE//1000
                audio_segments.append(monoaudio[start_frame:end_frame])
                timestamps.append((cur_start_time, cur_end_time))
                cur_start_time, cur_end_time = start_time, end_time
                cur_start_millisec, cur_end_millisec = start_millisec, end_millisec
            else:
                # Still within budget: extend the current chunk over this window.
                cur_end_time=end_time
                cur_end_millisec=end_millisec
        # Final update: flush whatever remains of this group's last chunk.
        start_frame, end_frame = cur_start_millisec*WHISPER_SAMPLE_RATE//1000, cur_end_millisec*WHISPER_SAMPLE_RATE//1000
        audio_segments.append(monoaudio[start_frame:end_frame])
        timestamps.append((cur_start_time, cur_end_time))
    return audio_segments, timestamps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -1,14 +1,43 @@
1
- from fastapi import FastAPI
2
- from holosubs import get_video_vtt
3
  from pydantic import BaseModel
 
 
 
 
 
 
4
 
5
  class Url(BaseModel):
6
  url: str
 
7
 
 
 
 
8
 
9
  app = FastAPI()
10
 
11
  @app.post("/captions/")
12
  def read_root(url: Url):
13
- vtt_captions = get_video_vtt(url.url)
14
- return {"captions": vtt_captions}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from youtubeaudio import YoutubeAudio
import subprocess
import os
import uuid
import logging
import ffmpeg


class Url(BaseModel):
    url: str


format = "%(asctime)s: %(message)s"
logging.basicConfig(format=format, level=logging.DEBUG,
                    datefmt="%H:%M:%S")
# Model name baked into the image at build time (see Dockerfile ARG/ENV).
MODEL = os.environ['model']

app = FastAPI()

@app.post("/captions/")
def read_root(url: Url):
    """Download a YouTube video's audio, run whisper.cpp on it and return WebVTT captions.

    Raises HTTPException(500) when the whisper binary exits non-zero.
    """
    # Download wav file and get filename.
    # NOTE(review): YoutubeAudio is handed the whole Url model (it reads
    # .url.url internally) — confirm before changing to url.url.
    ytaudio = YoutubeAudio(url)
    ytaudio.download_audio()
    filename = ytaudio.filename
    # Resample to 16 kHz, the rate whisper.cpp expects.
    ytaudio.resample('16k')
    # Generate subtitles into a unique output path.
    output_file = os.path.join("/tmp/holosubs/results", str(uuid.uuid4()))
    logging.info(f'Writing to file {output_file}.vtt')
    cmd = ['/usr/local/bin/whisper', '-m', f'/root/models/ggml-{MODEL}.bin',
           '-f', filename, '-di', '-of', output_file, '-tr', '-ovtt', '-t', '8']
    # stderr is merged into stdout, so communicate() always returns err=None;
    # the old `if err:` check could never fire. Detect failure via the exit code.
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    output, _ = p.communicate()
    logging.info(output)
    if p.returncode != 0:
        logging.error("Whisper translation failed with exit code %s", p.returncode)
        raise HTTPException(status_code=500, detail="Whisper translation failed")
    vtt_path = output_file + ".vtt"
    with open(vtt_path, 'r') as f:
        raw_vtt = f.read()
    # Remove the result file once its contents are in memory.
    os.remove(vtt_path)
    return {"captions": raw_vtt}
requirements.txt CHANGED
@@ -1,16 +1,5 @@
1
  fastapi==0.74.*
2
  requests==2.27.*
3
- sentencepiece==0.1.*
4
  uvicorn[standard]==0.17.*
5
- numpy==1.24.4
6
- pyannote.audio==1.1.2
7
- pyannote.core==5.0.0
8
- pyannote.database==5.0.1
9
- pyannote.metrics==3.2.1
10
- pyannote.pipeline==1.5.2
11
- python-dotenv==1.0.0
12
- torch==2.0.1
13
- torchaudio==2.0.2
14
- transformers==4.31.0
15
- webvtt_py==0.4.6
16
- yt_dlp==2023.7.6
 
1
  fastapi==0.74.*
2
  requests==2.27.*
 
3
  uvicorn[standard]==0.17.*
4
+ yt_dlp==2023.7.6
5
+ ffmpeg-python==0.2.0
 
 
 
 
 
 
 
 
 
 
transcribe.py DELETED
@@ -1,78 +0,0 @@
1
- """
2
- Represents a model that transcribes and translates audio.
3
- """
4
-
5
- import logging
6
- import os
7
- from typing import Union
8
-
9
- import numpy as np
10
- import torch
11
- from dotenv import load_dotenv
12
- from peft import PeftConfig, PeftModel
13
- from transformers import (AutomaticSpeechRecognitionPipeline,
14
- WhisperForConditionalGeneration, WhisperProcessor,
15
- WhisperTokenizer)
16
-
17
- load_dotenv()
18
- format = "%(asctime)s: %(message)s"
19
- logging.basicConfig(format=format, level=logging.DEBUG,
20
- datefmt="%H:%M:%S")
21
-
22
class Transcriber:
    """Transcribes/translates audio with a LoRA (PEFT)-adapted Whisper model.

    Loads the base Whisper checkpoint named in the PEFT config, applies the
    adapter, and wraps everything in a transformers ASR pipeline configured
    for the given language/task.
    """

    def __init__(self, model_id="teoha/openai-whisper-medium-LORA-ja", language="Japanese", task="translate"):
        """Build the ASR pipeline for the given adapter id, language and task."""
        self.language=language
        self.task=task
        # The default model_id is truthy, so the env fallback only applies
        # when a caller explicitly passes a falsy model_id.
        peft_model_id = model_id if model_id else os.getenv('peft_model_id')
        # TODO: Fix Download and install model locally
        # self.install_model(peft_model_id)
        self.initialize_pipe(peft_model_id) #initialize pipe

    def install_model(self, peft_model_id:str) -> None:
        """Download base model + adapter and save model/tokenizer/processor
        under $install_location/<peft_model_id> (currently unused; see the
        TODO in __init__)."""
        save_location = os.path.join(os.getenv('install_location'), peft_model_id)
        offload_location = os.path.join(os.getenv('install_location'), "offload")
        #Save Model
        peft_config = PeftConfig.from_pretrained(peft_model_id)
        model = WhisperForConditionalGeneration.from_pretrained(
            peft_config.base_model_name_or_path,
            load_in_8bit=False, device_map="auto"
        )
        # NOTE(review): offload_folder is the literal string "offload_location",
        # not the offload_location variable computed above — looks like a bug.
        model = PeftModel.from_pretrained(model, peft_model_id, offload_folder="offload_location")
        model.save_pretrained(save_location)

        #Save tokenizer/processor
        tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
        processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
        tokenizer.save_pretrained(save_location)
        processor.save_pretrained(save_location)
        logging.info("Installation Completed successfully")

    def initialize_pipe(self, peft_model_id: str) -> None:
        """Load base Whisper + PEFT adapter and build the in-memory pipeline."""
        offload_location = os.path.join(os.getenv('install_location'), "offload")
        # Initialize model configs
        peft_config = PeftConfig.from_pretrained(peft_model_id)
        model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, load_in_8bit=False, device_map="auto")
        model = PeftModel.from_pretrained(model, peft_model_id, offload_folder=offload_location)
        tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
        processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
        feature_extractor = processor.feature_extractor
        # Initialize class variables
        self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=self.language, task=self.task)
        self.pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
        logging.info("Pipe successfully initialized")


    def decode(self, audio: Union[np.ndarray, bytes, str]) -> str:
        '''
        Transcribes a sequence of floats representing an audio snippet.
        Args:
            inputs (:obj:`np.ndarray` or :obj:`bytes` or :obj:`str`):
                The inputs is either a raw waveform (:obj:`np.ndarray` of shape (n, ) of type :obj:`np.float32` or
                :obj:`np.float64`) at the correct sampling rate (no further check will be done) or a :obj:`str` that is
                the filename of the audio file, the file will be read at the correct sampling rate to get the waveform
                using `ffmpeg`. This requires `ffmpeg` to be installed on the system. If `inputs` is :obj:`bytes` it is
                supposed to be the content of an audio file and is interpreted by `ffmpeg` in the same way.
        '''
        # Autocast mixed precision around generation; forced_decoder_ids pin
        # the language/task prompt chosen in initialize_pipe.
        with torch.cuda.amp.autocast():
            text = self.pipe(audio, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids})["text"]
        return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
youtubeaudio.py CHANGED
@@ -2,13 +2,14 @@
2
  Represents a Youtube video
3
  """
4
 
5
- from dotenv import load_dotenv
6
  import logging
 
 
7
  from yt_dlp import YoutubeDL
8
  import os
9
  from pathlib import Path
 
10
 
11
- load_dotenv()
12
  format = "%(asctime)s: %(message)s"
13
  logging.basicConfig(format=format, level=logging.DEBUG,
14
  datefmt="%H:%M:%S")
@@ -17,7 +18,7 @@ class YoutubeAudio:
17
  def __init__(self, url, dir="/tmp/holosubs/audio"):
18
  self.url=url
19
  self.dir=dir
20
-
21
  def download_audio(self):
22
  ydl_opts = {
23
  'outtmpl': os.path.join(self.dir, "%(id)s_%(epoch)s.%(ext)s"),
@@ -26,12 +27,18 @@ class YoutubeAudio:
26
  'format': 'm4a/bestaudio/best',
27
  'postprocessors': [{ # Extract audio using ffmpeg
28
  'key': 'FFmpegExtractAudio',
29
- 'preferredcodec': 'wav',
30
  }]
31
  }
32
  with YoutubeDL(ydl_opts) as ydl:
33
- error_code = ydl.download([self.url])
34
 
 
 
 
 
 
 
35
  def clean(self):
36
  if not self.filename:
37
  logging.error("Audio not downloaded")
 
2
  Represents a Youtube video
3
  """
4
 
 
5
  import logging
6
+ import shutil
7
+ import uuid
8
  from yt_dlp import YoutubeDL
9
  import os
10
  from pathlib import Path
11
+ import ffmpeg
12
 
 
13
  format = "%(asctime)s: %(message)s"
14
  logging.basicConfig(format=format, level=logging.DEBUG,
15
  datefmt="%H:%M:%S")
 
18
  def __init__(self, url, dir="/tmp/holosubs/audio"):
19
  self.url=url
20
  self.dir=dir
21
+
22
  def download_audio(self):
23
  ydl_opts = {
24
  'outtmpl': os.path.join(self.dir, "%(id)s_%(epoch)s.%(ext)s"),
 
27
  'format': 'm4a/bestaudio/best',
28
  'postprocessors': [{ # Extract audio using ffmpeg
29
  'key': 'FFmpegExtractAudio',
30
+ 'preferredcodec': 'wav'
31
  }]
32
  }
33
  with YoutubeDL(ydl_opts) as ydl:
34
+ error_code = ydl.download([self.url.url])
35
 
36
+ def resample(self,sr='16k'):
37
+ tmp_filename=os.path.join(self.dir,str(uuid.uuid4()))+".wav"
38
+ ffmpeg.input(self.filename).output(tmp_filename,ar=sr).run()
39
+ shutil.move(tmp_filename, self.filename)
40
+ logging.info(f"Succesfuly resampled {self.filename}")
41
+
42
  def clean(self):
43
  if not self.filename:
44
  logging.error("Audio not downloaded")