ASureevaA commited on
Commit
b31d0e9
·
1 Parent(s): 9eec39f
Files changed (1) hide show
  1. app.py +146 -44
app.py CHANGED
@@ -1,9 +1,11 @@
1
  from typing import Tuple, Optional
 
2
  import tempfile
3
- import soundfile as sf
 
 
4
  import torch
5
- import gradio as gr
6
- import numpy as np
7
  from PIL import Image
8
  from transformers import (
9
  TrOCRProcessor,
@@ -14,28 +16,30 @@ from transformers import (
14
  )
15
 
16
 
 
 
17
  ocr_processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(
18
  "raxtemur/trocr-base-ru"
19
  )
20
  ocr_model: VisionEncoderDecoderModel = VisionEncoderDecoderModel.from_pretrained(
21
  "raxtemur/trocr-base-ru"
22
  )
23
- ocr_model.to("cpu")
24
 
25
  summary_pipeline = pipeline(
26
- "summarization",
27
  model="IlyaGusev/mbart_ru_sum_gazeta",
28
  tokenizer="IlyaGusev/mbart_ru_sum_gazeta",
29
  )
30
 
31
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-rus")
32
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
33
- tts_model.to("cpu")
34
 
35
  def run_ocr(image_object: Image.Image) -> str:
36
  """
37
  Распознавание текста с изображения.
38
- Предполагаем, что на картинке русский/кириллический или латинский печатный текст.
39
  """
40
  if image_object is None:
41
  return ""
@@ -46,7 +50,7 @@ def run_ocr(image_object: Image.Image) -> str:
46
  images=rgb_image_object,
47
  return_tensors="pt",
48
  )
49
- pixel_values_tensor = processor_output.pixel_values.to("cpu")
50
 
51
  generated_id_tensor = ocr_model.generate(pixel_values_tensor)
52
  decoded_text_list = ocr_processor.batch_decode(
@@ -57,48 +61,146 @@ def run_ocr(image_object: Image.Image) -> str:
57
  recognized_text: str = decoded_text_list[0]
58
  return recognized_text.strip()
59
 
60
- def run_summary(text: str) -> str:
61
- text = text.strip()
62
- if not text:
 
 
 
 
 
 
 
63
  return ""
64
- result = summary_pipeline(text, max_length=128, min_length=30, do_sample=False)
65
- return result[0]["summary_text"].strip()
66
 
67
- def run_tts(text: str) -> Optional[str]:
68
- text = text.strip()
69
- if not text:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  return None
71
- inputs = tts_tokenizer(text, return_tensors="pt").to("cpu")
72
- with torch.no_grad():
73
- waveform = tts_model(**inputs).waveform
74
- audio = waveform.squeeze().cpu().numpy().astype("float32")
75
- audio = np.clip(audio, -1.0, 1.0)
76
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
77
- sf.write(f.name, audio, tts_model.config.sampling_rate)
78
- return f.name
79
-
80
- def full_flow(image: Image.Image) -> Tuple[str, str, Optional[str]]:
81
- text = run_ocr(image)
82
- summary = run_summary(text)
83
- audio_path = run_tts(summary)
84
- return text, summary, audio_path
85
-
86
-
87
- demo = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  fn=full_flow,
89
- inputs=gr.Image(type="pil", label="Изображение с текстом (русский или английский)"),
 
 
 
 
 
 
 
 
 
 
 
 
90
  outputs=[
91
- gr.Textbox(label="Распознанный текст", lines=6),
92
- gr.Textbox(label="Краткий пересказ", lines=6),
93
- gr.Audio(label="Озвучка конспекта", type="filepath"),
 
 
 
 
 
 
 
 
 
94
  ],
95
- title="Картинка → Текст → Конспект → Озвучка (русская версия)",
96
  description=(
97
- "1️⃣ OCR (TrOCR-base) распознаёт текст с картинки.\n"
98
- "2️⃣ Суммаризация (IlyaGusev/mbart_ru_sum_gazeta) делает конспект.\n"
99
- "3️⃣ TTS (facebook/mms-tts-rus) озвучивает результат."
 
 
100
  ),
101
  )
102
 
 
103
  if __name__ == "__main__":
104
- demo.launch()
 
1
  from typing import Tuple, Optional
2
+
3
  import tempfile
4
+
5
+ import numpy as numpy_module
6
+ import soundfile as soundfile_module
7
  import torch
8
+ import gradio as gradio_module
 
9
  from PIL import Image
10
  from transformers import (
11
  TrOCRProcessor,
 
16
  )
17
 
18
 
19
+ device_string: str = "cpu"
20
+
21
  ocr_processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(
22
  "raxtemur/trocr-base-ru"
23
  )
24
  ocr_model: VisionEncoderDecoderModel = VisionEncoderDecoderModel.from_pretrained(
25
  "raxtemur/trocr-base-ru"
26
  )
27
+ ocr_model.to(device_string)
28
 
29
  summary_pipeline = pipeline(
30
+ task="summarization",
31
  model="IlyaGusev/mbart_ru_sum_gazeta",
32
  tokenizer="IlyaGusev/mbart_ru_sum_gazeta",
33
  )
34
 
35
+ tts_model: VitsModel = VitsModel.from_pretrained("facebook/mms-tts-rus")
36
+ tts_tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
37
+ tts_model.to(device_string)
38
 
39
  def run_ocr(image_object: Image.Image) -> str:
40
  """
41
  Распознавание текста с изображения.
42
+ Используем русскую TrOCR-модель.
43
  """
44
  if image_object is None:
45
  return ""
 
50
  images=rgb_image_object,
51
  return_tensors="pt",
52
  )
53
+ pixel_values_tensor = processor_output.pixel_values.to(device_string)
54
 
55
  generated_id_tensor = ocr_model.generate(pixel_values_tensor)
56
  decoded_text_list = ocr_processor.batch_decode(
 
61
  recognized_text: str = decoded_text_list[0]
62
  return recognized_text.strip()
63
 
64
+ def run_summarization(
65
+ input_text: str,
66
+ max_summary_tokens: int = 128,
67
+ ) -> str:
68
+ """
69
+ Русская суммаризация.
70
+ Без разбиения на чанки, так что огромные тексты лучше не подавать.
71
+ """
72
+ cleaned_text: str = input_text.strip()
73
+ if not cleaned_text:
74
  return ""
 
 
75
 
76
+ word_count: int = len(cleaned_text.split())
77
+ dynamic_max_length: int = min(
78
+ max_summary_tokens,
79
+ max(32, word_count + 20),
80
+ )
81
+
82
+ summary_result_list = summary_pipeline(
83
+ cleaned_text,
84
+ max_length=dynamic_max_length,
85
+ min_length=max(16, dynamic_max_length // 3),
86
+ do_sample=False,
87
+ )
88
+
89
+ summary_text: str = summary_result_list[0]["summary_text"].strip()
90
+ return summary_text
91
+
92
+
93
+ def run_tts(summary_text: str) -> Optional[str]:
94
+ """
95
+ Озвучка текста конспекта через VitsModel (facebook/mms-tts-rus).
96
+
97
+ ВАЖНО:
98
+ - защищаемся от пустого/битого ввода;
99
+ - ловим RuntimeError изнутри модели (известные проблемы MMS VITS на некоторых входах);
100
+ в этом случае просто возвращаем None, чтобы не ронять весь Space.
101
+ """
102
+ cleaned_text: str = summary_text.strip()
103
+ if not cleaned_text:
104
+ return None
105
+
106
+ tokenized_inputs = tts_tokenizer(
107
+ cleaned_text,
108
+ return_tensors="pt",
109
+ )
110
+ tokenized_inputs = {key: value.to(device_string) for key, value in tokenized_inputs.items()}
111
+
112
+ input_ids_tensor = tokenized_inputs.get("input_ids")
113
+ if input_ids_tensor is None:
114
  return None
115
+
116
+ if input_ids_tensor.numel() == 0 or input_ids_tensor.shape[1] == 0:
117
+ return None
118
+
119
+ try:
120
+ with torch.no_grad():
121
+ model_output = tts_model(**tokenized_inputs)
122
+ waveform_tensor = model_output.waveform # shape: (batch, n_samples)
123
+ except RuntimeError as runtime_error:
124
+ print(f"[WARN] TTS RuntimeError: {runtime_error}")
125
+ return None
126
+
127
+ waveform_array = waveform_tensor.squeeze().cpu().numpy().astype("float32")
128
+ waveform_array = numpy_module.clip(waveform_array, -1.0, 1.0)
129
+
130
+ with tempfile.NamedTemporaryFile(
131
+ suffix=".wav",
132
+ delete=False,
133
+ ) as temporary_file:
134
+ soundfile_module.write(
135
+ temporary_file.name,
136
+ waveform_array,
137
+ tts_model.config.sampling_rate,
138
+ )
139
+ file_path: str = temporary_file.name
140
+
141
+ return file_path
142
+
143
+ def full_flow(
144
+ image_object: Image.Image,
145
+ max_summary_tokens: int = 128,
146
+ ) -> Tuple[str, str, Optional[str]]:
147
+ """
148
+ Полный пайплайн:
149
+ 1) OCR: изображение -> исходный текст
150
+ 2) Суммаризация: текст -> конспект
151
+ 3) TTS: конспект -> .wav файл (или None, если TTS не смог)
152
+ """
153
+ recognized_text: str = run_ocr(image_object=image_object)
154
+
155
+ summary_text: str = run_summarization(
156
+ input_text=recognized_text,
157
+ max_summary_tokens=max_summary_tokens,
158
+ )
159
+
160
+ audio_file_path: Optional[str] = run_tts(summary_text=summary_text)
161
+
162
+ return recognized_text, summary_text, audio_file_path
163
+
164
+
165
+ gradio_interface = gradio_module.Interface(
166
  fn=full_flow,
167
+ inputs=[
168
+ gradio_module.Image(
169
+ type="pil",
170
+ label="Изображение с текстом (желательно русский/английский, печатный)",
171
+ ),
172
+ gradio_module.Slider(
173
+ minimum=32,
174
+ maximum=256,
175
+ value=128,
176
+ step=16,
177
+ label="Максимальная длина конспекта (токены, примерно)",
178
+ ),
179
+ ],
180
  outputs=[
181
+ gradio_module.Textbox(
182
+ label="Распознанный текст (OCR)",
183
+ lines=6,
184
+ ),
185
+ gradio_module.Textbox(
186
+ label="Конспект (суммаризация)",
187
+ lines=6,
188
+ ),
189
+ gradio_module.Audio(
190
+ label="Озвучка конспекта (MMS VITS, ru)",
191
+ type="filepath",
192
+ ),
193
  ],
194
+ title="Картинка → Текст → Конспект → Озвучка (русские модели)",
195
  description=(
196
+ "1) Русский трансформер OCR распознаёт текст с картинки.\n"
197
+ "2) Русский трансформер суммаризации делает краткий пересказ.\n"
198
+ "3) VITS-модель MMS (facebook/mms-tts-rus) озвучивает конспект.\n\n"
199
+ "Если озвучка не сгенерировалась, значит конкретный текст не понравился TTS-модели "
200
+ "и она упала внутри — пайплайн просто пропустит аудио."
201
  ),
202
  )
203
 
204
+
205
  if __name__ == "__main__":
206
+ gradio_interface.launch()