ASureevaA commited on
Commit
35e85d1
·
1 Parent(s): b31d0e9
Files changed (1) hide show
  1. app.py +80 -34
app.py CHANGED
@@ -11,35 +11,56 @@ from transformers import (
11
  TrOCRProcessor,
12
  VisionEncoderDecoderModel,
13
  pipeline,
14
- AutoTokenizer,
15
  VitsModel,
16
  )
17
 
 
 
 
18
 
19
- device_string: str = "cpu"
20
 
 
 
 
 
 
 
 
21
  ocr_processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(
22
- "raxtemur/trocr-base-ru"
23
  )
24
  ocr_model: VisionEncoderDecoderModel = VisionEncoderDecoderModel.from_pretrained(
25
- "raxtemur/trocr-base-ru"
26
  )
27
  ocr_model.to(device_string)
28
 
 
 
29
  summary_pipeline = pipeline(
30
  task="summarization",
31
- model="IlyaGusev/mbart_ru_sum_gazeta",
32
- tokenizer="IlyaGusev/mbart_ru_sum_gazeta",
33
  )
34
 
35
- tts_model: VitsModel = VitsModel.from_pretrained("facebook/mms-tts-rus")
36
- tts_tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
 
 
37
  tts_model.to(device_string)
38
 
 
 
 
 
 
39
  def run_ocr(image_object: Image.Image) -> str:
40
  """
41
- Распознавание текста с изображения.
42
- Используем русскую TrOCR-модель.
 
 
 
43
  """
44
  if image_object is None:
45
  return ""
@@ -52,28 +73,38 @@ def run_ocr(image_object: Image.Image) -> str:
52
  )
53
  pixel_values_tensor = processor_output.pixel_values.to(device_string)
54
 
55
- generated_id_tensor = ocr_model.generate(pixel_values_tensor)
 
 
56
  decoded_text_list = ocr_processor.batch_decode(
57
  generated_id_tensor,
58
  skip_special_tokens=True,
59
  )
60
 
61
- recognized_text: str = decoded_text_list[0]
62
- return recognized_text.strip()
 
 
 
 
 
63
 
64
  def run_summarization(
65
  input_text: str,
66
  max_summary_tokens: int = 128,
67
  ) -> str:
68
  """
69
- Русская суммаризация.
70
- Без разбиения на чанки, так что огромные тексты лучше не подавать.
71
  """
72
  cleaned_text: str = input_text.strip()
73
  if not cleaned_text:
74
  return ""
75
 
76
  word_count: int = len(cleaned_text.split())
 
 
 
77
  dynamic_max_length: int = min(
78
  max_summary_tokens,
79
  max(32, word_count + 20),
@@ -82,7 +113,7 @@ def run_summarization(
82
  summary_result_list = summary_pipeline(
83
  cleaned_text,
84
  max_length=dynamic_max_length,
85
- min_length=max(16, dynamic_max_length // 3),
86
  do_sample=False,
87
  )
88
 
@@ -90,13 +121,17 @@ def run_summarization(
90
  return summary_text
91
 
92
 
 
 
 
 
93
  def run_tts(summary_text: str) -> Optional[str]:
94
  """
95
- Озвучка текста конспекта через VitsModel (facebook/mms-tts-rus).
96
 
97
  ВАЖНО:
98
- - защищаемся от пустого/битого ввода;
99
- - ловим RuntimeError изнутри модели (известные проблемы MMS VITS на некоторых входах);
100
  в это�� случае просто возвращаем None, чтобы не ронять весь Space.
101
  """
102
  cleaned_text: str = summary_text.strip()
@@ -107,7 +142,11 @@ def run_tts(summary_text: str) -> Optional[str]:
107
  cleaned_text,
108
  return_tensors="pt",
109
  )
110
- tokenized_inputs = {key: value.to(device_string) for key, value in tokenized_inputs.items()}
 
 
 
 
111
 
112
  input_ids_tensor = tokenized_inputs.get("input_ids")
113
  if input_ids_tensor is None:
@@ -140,15 +179,20 @@ def run_tts(summary_text: str) -> Optional[str]:
140
 
141
  return file_path
142
 
 
 
 
 
 
143
  def full_flow(
144
  image_object: Image.Image,
145
  max_summary_tokens: int = 128,
146
  ) -> Tuple[str, str, Optional[str]]:
147
  """
148
  Полный пайплайн:
149
- 1) OCR: изображение -> исходный текст
150
- 2) Суммаризация: текст -> конспект
151
- 3) TTS: конспект -> .wav файл (или None, если TTS не смог)
152
  """
153
  recognized_text: str = run_ocr(image_object=image_object)
154
 
@@ -162,42 +206,44 @@ def full_flow(
162
  return recognized_text, summary_text, audio_file_path
163
 
164
 
 
 
 
 
165
  gradio_interface = gradio_module.Interface(
166
  fn=full_flow,
167
  inputs=[
168
  gradio_module.Image(
169
  type="pil",
170
- label="Изображение с текстом (желательно русский/английский, печатный)",
171
  ),
172
  gradio_module.Slider(
173
  minimum=32,
174
  maximum=256,
175
  value=128,
176
  step=16,
177
- label="Максимальная длина конспекта (токены, примерно)",
178
  ),
179
  ],
180
  outputs=[
181
  gradio_module.Textbox(
182
- label="Распознанный текст (OCR)",
183
  lines=6,
184
  ),
185
  gradio_module.Textbox(
186
- label="Конспект (суммаризация)",
187
  lines=6,
188
  ),
189
  gradio_module.Audio(
190
- label="Озвучка конспекта (MMS VITS, ru)",
191
  type="filepath",
192
  ),
193
  ],
194
- title="КартинкаТекстКонспектОзвучка (русские модели)",
195
  description=(
196
- "1) Русский трансформер OCR распознаёт текст с картинки.\n"
197
- "2) Русский трансформер суммаризации делает краткий пересказ.\n"
198
- "3) VITS-модель MMS (facebook/mms-tts-rus) озвучивает конспект.\n\n"
199
- "Если озвучка не сгенерировалась, значит конкретный текст не понравился TTS-модели "
200
- "и она упала внутри — пайплайн просто пропустит аудио."
201
  ),
202
  )
203
 
 
11
  TrOCRProcessor,
12
  VisionEncoderDecoderModel,
13
  pipeline,
14
+ VitsTokenizer,
15
  VitsModel,
16
  )
17
 
18
+ # ============================
19
+ # 1. Настройки устройства
20
+ # ============================
21
 
22
+ device_string: str = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
+
25
+ # ============================
26
+ # 2. Модели
27
+ # ============================
28
+
29
+ # OCR: печатный английский текст
30
+ # Модель: microsoft/trocr-small-printed
31
  ocr_processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(
32
+ "microsoft/trocr-small-printed"
33
  )
34
  ocr_model: VisionEncoderDecoderModel = VisionEncoderDecoderModel.from_pretrained(
35
+ "microsoft/trocr-small-printed"
36
  )
37
  ocr_model.to(device_string)
38
 
39
+ # Суммаризация: английский новостной/общий текст
40
+ # Модель: sshleifer/distilbart-cnn-12-6
41
  summary_pipeline = pipeline(
42
  task="summarization",
43
+ model="sshleifer/distilbart-cnn-12-6",
 
44
  )
45
 
46
+ # TTS: английская MMS VITS
47
+ # Модель: facebook/mms-tts-eng
48
+ tts_model: VitsModel = VitsModel.from_pretrained("facebook/mms-tts-eng")
49
+ tts_tokenizer: VitsTokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
50
  tts_model.to(device_string)
51
 
52
+
53
+ # ============================
54
+ # 3. OCR
55
+ # ============================
56
+
57
  def run_ocr(image_object: Image.Image) -> str:
58
  """
59
+ Распознавание печатного английского текста с изображения.
60
+ Используем TrOCR (microsoft/trocr-small-printed).
61
+
62
+ Ожидается более-менее читаемый printed text
63
+ (скриншоты, документы, слайды и т.п.).
64
  """
65
  if image_object is None:
66
  return ""
 
73
  )
74
  pixel_values_tensor = processor_output.pixel_values.to(device_string)
75
 
76
+ with torch.no_grad():
77
+ generated_id_tensor = ocr_model.generate(pixel_values_tensor)
78
+
79
  decoded_text_list = ocr_processor.batch_decode(
80
  generated_id_tensor,
81
  skip_special_tokens=True,
82
  )
83
 
84
+ recognized_text: str = decoded_text_list[0].strip()
85
+ return recognized_text
86
+
87
+
88
+ # ============================
89
+ # 4. Суммаризация (английский)
90
+ # ============================
91
 
92
  def run_summarization(
93
  input_text: str,
94
  max_summary_tokens: int = 128,
95
  ) -> str:
96
  """
97
+ Английская суммаризация.
98
+ Без разбиения на чанки, поэтому очень длинные тексты лучше не подавать.
99
  """
100
  cleaned_text: str = input_text.strip()
101
  if not cleaned_text:
102
  return ""
103
 
104
  word_count: int = len(cleaned_text.split())
105
+
106
+ # Простая адаптация длины под размер текста,
107
+ # чтобы не было бессмысленных max_length >> input_length.
108
  dynamic_max_length: int = min(
109
  max_summary_tokens,
110
  max(32, word_count + 20),
 
113
  summary_result_list = summary_pipeline(
114
  cleaned_text,
115
  max_length=dynamic_max_length,
116
+ min_length=max(10, dynamic_max_length // 3),
117
  do_sample=False,
118
  )
119
 
 
121
  return summary_text
122
 
123
 
124
+ # ============================
125
+ # 5. TTS (английский, MMS VITS)
126
+ # ============================
127
+
128
  def run_tts(summary_text: str) -> Optional[str]:
129
  """
130
+ Озвучка английского текста конспекта через VitsModel (facebook/mms-tts-eng).
131
 
132
  ВАЖНО:
133
+ - защищаемся от пустого ввода;
134
+ - ловим RuntimeError изнутри модели (бывают краши на редких входах);
135
  в это�� случае просто возвращаем None, чтобы не ронять весь Space.
136
  """
137
  cleaned_text: str = summary_text.strip()
 
142
  cleaned_text,
143
  return_tensors="pt",
144
  )
145
+
146
+ tokenized_inputs = {
147
+ key: value.to(device_string)
148
+ for key, value in tokenized_inputs.items()
149
+ }
150
 
151
  input_ids_tensor = tokenized_inputs.get("input_ids")
152
  if input_ids_tensor is None:
 
179
 
180
  return file_path
181
 
182
+
183
+ # ============================
184
+ # 6. Полный пайплайн
185
+ # ============================
186
+
187
  def full_flow(
188
  image_object: Image.Image,
189
  max_summary_tokens: int = 128,
190
  ) -> Tuple[str, str, Optional[str]]:
191
  """
192
  Полный пайплайн:
193
+ 1) OCR: изображение -> исходный текст (английский)
194
+ 2) Суммаризация: текст -> краткое резюме
195
+ 3) TTS: резюме -> .wav файл (или None, если TTS не смог)
196
  """
197
  recognized_text: str = run_ocr(image_object=image_object)
198
 
 
206
  return recognized_text, summary_text, audio_file_path
207
 
208
 
209
+ # ============================
210
+ # 7. Gradio UI
211
+ # ============================
212
+
213
  gradio_interface = gradio_module.Interface(
214
  fn=full_flow,
215
  inputs=[
216
  gradio_module.Image(
217
  type="pil",
218
+ label="Image with printed English text",
219
  ),
220
  gradio_module.Slider(
221
  minimum=32,
222
  maximum=256,
223
  value=128,
224
  step=16,
225
+ label="Maximum summary length (tokens, approx)",
226
  ),
227
  ],
228
  outputs=[
229
  gradio_module.Textbox(
230
+ label="Recognized text (OCR)",
231
  lines=6,
232
  ),
233
  gradio_module.Textbox(
234
+ label="Summary (English)",
235
  lines=6,
236
  ),
237
  gradio_module.Audio(
238
+ label="Summary narration (MMS VITS, en)",
239
  type="filepath",
240
  ),
241
  ],
242
+ title="ImageTextSummarySpeech (English models)",
243
  description=(
244
+ "1) English OCR transformer recognizes printed text from the image.\n"
245
+ "2) English summarization transformer creates a short summary.\n"
246
+ "3) English VITS (facebook/mms-tts-eng) reads the summary aloud."
 
 
247
  ),
248
  )
249