ASureevaA commited on
Commit
f80380f
·
1 Parent(s): 7585a19
Files changed (2) hide show
  1. app.py +55 -25
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,11 +1,18 @@
1
- from typing import Tuple, Optional, Any
 
 
 
2
  import numpy as np
 
 
3
  import gradio as gradio_module
4
  from PIL import Image
5
  from transformers import (
6
  TrOCRProcessor,
7
  VisionEncoderDecoderModel,
8
  pipeline,
 
 
9
  )
10
 
11
  ocr_processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(
@@ -20,10 +27,12 @@ summary_pipeline = pipeline(
20
  model="sshleifer/distilbart-cnn-12-6",
21
  )
22
 
23
- text_to_speech_pipeline = pipeline(
24
- task="text-to-speech",
25
- model="facebook/mms-tts-eng",
26
- )
 
 
27
 
28
 
29
  def run_ocr(image_object: Image.Image) -> str:
@@ -38,7 +47,7 @@ def run_ocr(image_object: Image.Image) -> str:
38
  images=image_object,
39
  return_tensors="pt",
40
  )
41
- pixel_values_tensor = processor_output.pixel_values
42
 
43
  generated_id_tensor = ocr_model.generate(pixel_values_tensor)
44
  decoded_text_list = ocr_processor.batch_decode(
@@ -56,16 +65,22 @@ def run_summarization(
56
  ) -> str:
57
  """
58
  Суммаризация текста до короткого конспекта.
59
- Без разбиения на чанки, поэтому огромные тексты лучше не подавать.
60
  """
61
  cleaned_text: str = input_text.strip()
62
  if not cleaned_text:
63
  return ""
64
 
 
 
 
 
 
 
65
  summary_result_list = summary_pipeline(
66
  cleaned_text,
67
- max_length=max_summary_tokens,
68
- min_length=max(16, max_summary_tokens // 3),
69
  do_sample=False,
70
  )
71
 
@@ -73,30 +88,44 @@ def run_summarization(
73
  return summary_text
74
 
75
 
76
- def run_tts(summary_text: str) -> Optional[Tuple[int, Any]]:
77
  """
78
- Озвучка текста конспекта.
79
- Используем модель, которой не нужны внешние speaker embeddings.
80
  """
81
  cleaned_text: str = summary_text.strip()
82
  if not cleaned_text:
83
  return None
84
 
85
- tts_output = text_to_speech_pipeline(cleaned_text)
 
 
 
86
 
87
- sampling_rate_int: int = int(tts_output["sampling_rate"])
88
- audio_array = tts_output["audio"]
 
89
 
90
- audio_array = np.array(audio_array, dtype=np.float32)
91
- audio_array = np.clip(audio_array, -1.0, 1.0)
92
 
93
- return sampling_rate_int, audio_array
 
 
 
 
 
 
 
 
 
 
 
94
 
95
 
96
  def full_flow(
97
  image_object: Image.Image,
98
  max_summary_tokens: int = 128,
99
- ) -> Tuple[str, str, Optional[Tuple[int, Any]]]:
100
 
101
  recognized_text: str = run_ocr(image_object=image_object)
102
 
@@ -105,16 +134,17 @@ def full_flow(
105
  max_summary_tokens=max_summary_tokens,
106
  )
107
 
108
- audio_tuple = run_tts(summary_text=summary_text)
 
 
109
 
110
- return recognized_text, summary_text, audio_tuple
111
 
112
  gradio_interface = gradio_module.Interface(
113
  fn=full_flow,
114
  inputs=[
115
  gradio_module.Image(
116
  type="pil",
117
- label="Изображение с напечатанным текстом (английский)",
118
  ),
119
  gradio_module.Slider(
120
  minimum=32,
@@ -134,15 +164,15 @@ gradio_interface = gradio_module.Interface(
134
  lines=6,
135
  ),
136
  gradio_module.Audio(
137
- label="Озвучка конспекта (TTS)",
138
- type="numpy",
139
  ),
140
  ],
141
  title="Картинка → Конспект → Озвучка (Transformers)",
142
  description=(
143
  "1) Трансформер OCR распознаёт текст с изображения. "
144
  "2) Трансформер суммаризации сокращает текст до конспекта. "
145
- "3) Трансформер TTS озвучивает конспект."
146
  ),
147
  )
148
 
 
1
+ from typing import Tuple, Optional
2
+
3
+ import tempfile
4
+
5
  import numpy as np
6
+ import soundfile as soundfile_module
7
+ import torch
8
  import gradio as gradio_module
9
  from PIL import Image
10
  from transformers import (
11
  TrOCRProcessor,
12
  VisionEncoderDecoderModel,
13
  pipeline,
14
+ VitsModel,
15
+ AutoTokenizer,
16
  )
17
 
18
  ocr_processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(
 
27
  model="sshleifer/distilbart-cnn-12-6",
28
  )
29
 
30
+ tts_model: VitsModel = VitsModel.from_pretrained("facebook/mms-tts-rus")
31
+ tts_tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
32
+
33
+ device_string: str = "cpu"
34
+ ocr_model.to(device_string)
35
+ tts_model.to(device_string)
36
 
37
 
38
  def run_ocr(image_object: Image.Image) -> str:
 
47
  images=image_object,
48
  return_tensors="pt",
49
  )
50
+ pixel_values_tensor = processor_output.pixel_values.to(device_string)
51
 
52
  generated_id_tensor = ocr_model.generate(pixel_values_tensor)
53
  decoded_text_list = ocr_processor.batch_decode(
 
65
  ) -> str:
66
  """
67
  Суммаризация текста до короткого конспекта.
68
+ Без сложного разбиения на чанки -> длинные тексты лучше не кормить.
69
  """
70
  cleaned_text: str = input_text.strip()
71
  if not cleaned_text:
72
  return ""
73
 
74
+ word_count: int = len(cleaned_text.split())
75
+ dynamic_max_length: int = min(
76
+ max_summary_tokens,
77
+ max(32, word_count + 20),
78
+ )
79
+
80
  summary_result_list = summary_pipeline(
81
  cleaned_text,
82
+ max_length=dynamic_max_length,
83
+ min_length=max(10, dynamic_max_length // 3),
84
  do_sample=False,
85
  )
86
 
 
88
  return summary_text
89
 
90
 
91
+ def run_tts(summary_text: str) -> Optional[str]:
92
  """
93
+ Озвучка текста конспекта через VitsModel (facebook/mms-tts-rus).
94
+ Возвращаем путь до временного .wav файла, который Gradio отдаст в плеер.
95
  """
96
  cleaned_text: str = summary_text.strip()
97
  if not cleaned_text:
98
  return None
99
 
100
+ tokenized_inputs = tts_tokenizer(
101
+ cleaned_text,
102
+ return_tensors="pt",
103
+ ).to(device_string)
104
 
105
+ with torch.no_grad():
106
+ model_output = tts_model(**tokenized_inputs)
107
+ waveform_tensor = model_output.waveform
108
 
109
+ waveform_array = waveform_tensor.squeeze().cpu().numpy().astype("float32")
 
110
 
111
+ with tempfile.NamedTemporaryFile(
112
+ suffix=".wav",
113
+ delete=False,
114
+ ) as temporary_file:
115
+ soundfile_module.write(
116
+ temporary_file.name,
117
+ waveform_array,
118
+ tts_model.config.sampling_rate,
119
+ )
120
+ file_path: str = temporary_file.name
121
+
122
+ return file_path
123
 
124
 
125
  def full_flow(
126
  image_object: Image.Image,
127
  max_summary_tokens: int = 128,
128
+ ) -> Tuple[str, str, Optional[str]]:
129
 
130
  recognized_text: str = run_ocr(image_object=image_object)
131
 
 
134
  max_summary_tokens=max_summary_tokens,
135
  )
136
 
137
+ audio_file_path: Optional[str] = run_tts(summary_text=summary_text)
138
+
139
+ return recognized_text, summary_text, audio_file_path
140
 
 
141
 
142
  gradio_interface = gradio_module.Interface(
143
  fn=full_flow,
144
  inputs=[
145
  gradio_module.Image(
146
  type="pil",
147
+ label="Изображение с напечатанным текстом (лучше русским/латиницей)",
148
  ),
149
  gradio_module.Slider(
150
  minimum=32,
 
164
  lines=6,
165
  ),
166
  gradio_module.Audio(
167
+ label="Озвучка конспекта (VITS, ru)",
168
+ type="filepath",
169
  ),
170
  ],
171
  title="Картинка → Конспект → Озвучка (Transformers)",
172
  description=(
173
  "1) Трансформер OCR распознаёт текст с изображения. "
174
  "2) Трансформер суммаризации сокращает текст до конспекта. "
175
+ "3) VITS-модель (facebook/mms-tts-rus) озвучивает конспект по-русски."
176
  ),
177
  )
178
 
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
- transformers
2
  torch
3
  sentencepiece
4
  gradio
5
  Pillow
6
  numpy
 
 
1
+ transformers>=4.33.0
2
  torch
3
  sentencepiece
4
  gradio
5
  Pillow
6
  numpy
7
+ soundfile