timgremore committed on
Commit
7494a1c
·
1 Parent(s): 93ae114

feat: Try lighter-weight and more accurate whisper

Browse files
Files changed (2) hide show
  1. fly.toml +1 -1
  2. lib/medicode/serving_supervisor.ex +33 -3
fly.toml CHANGED
@@ -24,7 +24,7 @@ kill_signal = 'SIGTERM'
24
  [[mounts]]
25
  source = 'data'
26
  destination = '/data'
27
- initial_size = '100gb'
28
 
29
  [http_service]
30
  internal_port = 8080
 
24
  [[mounts]]
25
  source = 'data'
26
  destination = '/data'
27
+ initial_size = '40gb'
28
 
29
  [http_service]
30
  internal_port = 8080
lib/medicode/serving_supervisor.ex CHANGED
@@ -7,7 +7,8 @@ defmodule Medicode.ServingSupervisor do
7
 
8
  alias AudioTagger.{KeywordFinder, Transcriber, Vectors}
9
 
10
- @model_name "openai/whisper-small"
 
11
  @question_answer_model_name "distilbert-base-cased-distilled-squad"
12
 
13
  def start_link(init_arg) do
@@ -20,14 +21,20 @@ defmodule Medicode.ServingSupervisor do
20
  transcription_spec(),
21
  token_classification_spec(),
22
  text_embedding_spec(),
23
- question_answer_spec(),
24
  ]
25
 
26
  Supervisor.init(children, strategy: :one_for_one)
27
  end
28
 
29
  defp transcription_spec do
30
- Transcriber.child_spec(Medicode.TranscriptionServing, @model_name)
 
 
 
 
 
 
31
  end
32
 
33
  defp token_classification_spec do
@@ -53,4 +60,27 @@ defmodule Medicode.ServingSupervisor do
53
  serving: serving, name: Medicode.QAServing, batch_size: 1, batch_timeout: 100
54
  }
55
  end
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  end
 
7
 
8
  alias AudioTagger.{KeywordFinder, Transcriber, Vectors}
9
 
10
+ # @model_name "openai/whisper-small"
11
+ @model_name "distil-whisper/distil-medium.en"
12
  @question_answer_model_name "distilbert-base-cased-distilled-squad"
13
 
14
  def start_link(init_arg) do
 
21
  transcription_spec(),
22
  token_classification_spec(),
23
  text_embedding_spec(),
24
+ question_answer_spec()
25
  ]
26
 
27
  Supervisor.init(children, strategy: :one_for_one)
28
  end
29
 
30
  defp transcription_spec do
31
+ {:ok, featurizer} = Bumblebee.load_featurizer({:hf, @model_name})
32
+ serving = serving_with_featurizer(featurizer, @model_name)
33
+
34
+ {
35
+ Nx.Serving,
36
+ serving: serving, name: Medicode.TranscriptionServing, batch_size: 4, batch_timeout: 100
37
+ }
38
  end
39
 
40
  defp token_classification_spec do
 
60
  serving: serving, name: Medicode.QAServing, batch_size: 1, batch_timeout: 100
61
  }
62
  end
63
+
64
+ @doc "Creates an Nx.Serving to perform speech-to-text tasks, using the passed featurizer. This is helpful for direct use from Livebook where the featurizer is needed to define the Kino audio input."
65
+ def serving_with_featurizer(featurizer, model_name) do
66
+ {:ok, model_info} = Bumblebee.load_model({:hf, model_name})
67
+ {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})
68
+ {:ok, generation_config} = Bumblebee.load_generation_config({:hf, model_name})
69
+ generation_config = Bumblebee.configure(generation_config, max_new_tokens: 100)
70
+
71
+ # Docs: https://hexdocs.pm/bumblebee/Bumblebee.Audio.html#speech_to_text_whisper/5
72
+ Bumblebee.Audio.speech_to_text_whisper(
73
+ model_info,
74
+ featurizer,
75
+ tokenizer,
76
+ generation_config,
77
+ task: nil,
78
+ compile: [batch_size: 4],
79
+ chunk_num_seconds: 30,
80
+ # context_num_seconds: 5, # Defaults to 1/6 of :chunk_num_seconds
81
+ timestamps: :segments,
82
+ stream: true,
83
+ defn_options: [compiler: EXLA]
84
+ )
85
+ end
86
  end