ASureevaA commited on
Commit
f6e6de6
·
1 Parent(s): f80380f
Files changed (1) hide show
  1. app.py +51 -149
app.py CHANGED
@@ -1,181 +1,83 @@
1
  from typing import Tuple, Optional
2
-
3
  import tempfile
4
-
5
- import numpy as np
6
- import soundfile as soundfile_module
7
  import torch
8
- import gradio as gradio_module
 
9
  from PIL import Image
10
  from transformers import (
11
  TrOCRProcessor,
12
  VisionEncoderDecoderModel,
13
  pipeline,
14
- VitsModel,
15
  AutoTokenizer,
 
16
  )
17
 
18
- ocr_processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(
19
- "microsoft/trocr-small-printed"
20
- )
21
- ocr_model: VisionEncoderDecoderModel = VisionEncoderDecoderModel.from_pretrained(
22
- "microsoft/trocr-small-printed"
23
- )
24
 
25
  summary_pipeline = pipeline(
26
- task="summarization",
27
- model="sshleifer/distilbart-cnn-12-6",
 
28
  )
29
 
30
- tts_model: VitsModel = VitsModel.from_pretrained("facebook/mms-tts-rus")
31
- tts_tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
32
-
33
- device_string: str = "cpu"
34
- ocr_model.to(device_string)
35
- tts_model.to(device_string)
36
 
37
-
38
- def run_ocr(image_object: Image.Image) -> str:
39
- """
40
- Распознавание текста с изображения.
41
- Предполагаем, что на картинке простой напечатанный текст.
42
- """
43
- if image_object is None:
44
  return ""
45
-
46
- processor_output = ocr_processor(
47
- images=image_object,
48
- return_tensors="pt",
49
- )
50
- pixel_values_tensor = processor_output.pixel_values.to(device_string)
51
-
52
- generated_id_tensor = ocr_model.generate(pixel_values_tensor)
53
- decoded_text_list = ocr_processor.batch_decode(
54
- generated_id_tensor,
55
- skip_special_tokens=True,
56
- )
57
-
58
- recognized_text: str = decoded_text_list[0]
59
- return recognized_text.strip()
60
-
61
-
62
- def run_summarization(
63
- input_text: str,
64
- max_summary_tokens: int = 128,
65
- ) -> str:
66
- """
67
- Суммаризация текста до короткого конспекта.
68
- Без сложного разбиения на чанки -> длинные тексты лучше не кормить.
69
- """
70
- cleaned_text: str = input_text.strip()
71
- if not cleaned_text:
72
  return ""
 
 
73
 
74
- word_count: int = len(cleaned_text.split())
75
- dynamic_max_length: int = min(
76
- max_summary_tokens,
77
- max(32, word_count + 20),
78
- )
79
-
80
- summary_result_list = summary_pipeline(
81
- cleaned_text,
82
- max_length=dynamic_max_length,
83
- min_length=max(10, dynamic_max_length // 3),
84
- do_sample=False,
85
- )
86
-
87
- summary_text: str = summary_result_list[0]["summary_text"].strip()
88
- return summary_text
89
-
90
-
91
- def run_tts(summary_text: str) -> Optional[str]:
92
- """
93
- Озвучка текста конспекта через VitsModel (facebook/mms-tts-rus).
94
- Возвращаем путь до временного .wav файла, который Gradio отдаст в плеер.
95
- """
96
- cleaned_text: str = summary_text.strip()
97
- if not cleaned_text:
98
  return None
99
-
100
- tokenized_inputs = tts_tokenizer(
101
- cleaned_text,
102
- return_tensors="pt",
103
- ).to(device_string)
104
-
105
  with torch.no_grad():
106
- model_output = tts_model(**tokenized_inputs)
107
- waveform_tensor = model_output.waveform
108
-
109
- waveform_array = waveform_tensor.squeeze().cpu().numpy().astype("float32")
110
-
111
- with tempfile.NamedTemporaryFile(
112
- suffix=".wav",
113
- delete=False,
114
- ) as temporary_file:
115
- soundfile_module.write(
116
- temporary_file.name,
117
- waveform_array,
118
- tts_model.config.sampling_rate,
119
- )
120
- file_path: str = temporary_file.name
121
 
122
- return file_path
 
 
 
 
123
 
124
 
125
- def full_flow(
126
- image_object: Image.Image,
127
- max_summary_tokens: int = 128,
128
- ) -> Tuple[str, str, Optional[str]]:
129
-
130
- recognized_text: str = run_ocr(image_object=image_object)
131
-
132
- summary_text: str = run_summarization(
133
- input_text=recognized_text,
134
- max_summary_tokens=max_summary_tokens,
135
- )
136
-
137
- audio_file_path: Optional[str] = run_tts(summary_text=summary_text)
138
-
139
- return recognized_text, summary_text, audio_file_path
140
-
141
-
142
- gradio_interface = gradio_module.Interface(
143
  fn=full_flow,
144
- inputs=[
145
- gradio_module.Image(
146
- type="pil",
147
- label="Изображение с напечатанным текстом (лучше русским/латиницей)",
148
- ),
149
- gradio_module.Slider(
150
- minimum=32,
151
- maximum=256,
152
- value=128,
153
- step=16,
154
- label="Максимальная длина конспекта (токены, примерно)",
155
- ),
156
- ],
157
  outputs=[
158
- gradio_module.Textbox(
159
- label="Распознанный текст (OCR)",
160
- lines=6,
161
- ),
162
- gradio_module.Textbox(
163
- label="Конспект (суммаризация)",
164
- lines=6,
165
- ),
166
- gradio_module.Audio(
167
- label="Озвучка конспекта (VITS, ru)",
168
- type="filepath",
169
- ),
170
  ],
171
- title="Картинка → Конспект → Озвучка (Transformers)",
172
  description=(
173
- "1) Трансформер OCR распознаёт текст с изображения. "
174
- "2) Трансформер суммаризации сокращает текст до конспекта. "
175
- "3) VITS-модель (facebook/mms-tts-rus) озвучивает конспект по-русски."
176
  ),
177
  )
178
 
179
-
180
  if __name__ == "__main__":
181
- gradio_interface.launch()
 
1
  from typing import Tuple, Optional
 
2
  import tempfile
3
+ import soundfile as sf
 
 
4
  import torch
5
+ import gradio as gr
6
+ import numpy as np
7
  from PIL import Image
8
  from transformers import (
9
  TrOCRProcessor,
10
  VisionEncoderDecoderModel,
11
  pipeline,
 
12
  AutoTokenizer,
13
+ VitsModel,
14
  )
15
 
16
+
17
+ ocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-stage1")
18
+ ocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
19
+ ocr_model.to("cpu")
 
 
20
 
21
  summary_pipeline = pipeline(
22
+ "summarization",
23
+ model="IlyaGusev/mbart_ru_sum_gazeta",
24
+ tokenizer="IlyaGusev/mbart_ru_sum_gazeta",
25
  )
26
 
27
+ tts_model = VitsModel.from_pretrained("facebook/mms-tts-rus")
28
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
29
+ tts_model.to("cpu")
 
 
 
30
 
31
+ def run_ocr(image: Image.Image) -> str:
32
+ if image is None:
 
 
 
 
 
33
  return ""
34
+ pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
35
+ generated_ids = ocr_model.generate(pixel_values)
36
+ text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
37
+ return text.strip()
38
+
39
+ def run_summary(text: str) -> str:
40
+ text = text.strip()
41
+ if not text:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  return ""
43
+ result = summary_pipeline(text, max_length=128, min_length=30, do_sample=False)
44
+ return result[0]["summary_text"].strip()
45
 
46
+ def run_tts(text: str) -> Optional[str]:
47
+ text = text.strip()
48
+ if not text:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  return None
50
+ inputs = tts_tokenizer(text, return_tensors="pt").to("cpu")
 
 
 
 
 
51
  with torch.no_grad():
52
+ waveform = tts_model(**inputs).waveform
53
+ audio = waveform.squeeze().cpu().numpy().astype("float32")
54
+ audio = np.clip(audio, -1.0, 1.0)
55
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
56
+ sf.write(f.name, audio, tts_model.config.sampling_rate)
57
+ return f.name
 
 
 
 
 
 
 
 
 
58
 
59
+ def full_flow(image: Image.Image) -> Tuple[str, str, Optional[str]]:
60
+ text = run_ocr(image)
61
+ summary = run_summary(text)
62
+ audio_path = run_tts(summary)
63
+ return text, summary, audio_path
64
 
65
 
66
+ demo = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  fn=full_flow,
68
+ inputs=gr.Image(type="pil", label="Изображение с текстом (русский или английский)"),
 
 
 
 
 
 
 
 
 
 
 
 
69
  outputs=[
70
+ gr.Textbox(label="Распознанный текст", lines=6),
71
+ gr.Textbox(label="Краткий пересказ", lines=6),
72
+ gr.Audio(label="Озвучка конспекта", type="filepath"),
 
 
 
 
 
 
 
 
 
73
  ],
74
+ title="Картинка → Текст → Конспект → Озвучка (русская версия)",
75
  description=(
76
+ "1️⃣ OCR (TrOCR-base) распознаёт текст с картинки.\n"
77
+ "2️⃣ Суммаризация (IlyaGusev/mbart_ru_sum_gazeta) делает конспект.\n"
78
+ "3️⃣ TTS (facebook/mms-tts-rus) озвучивает результат."
79
  ),
80
  )
81
 
 
82
  if __name__ == "__main__":
83
+ demo.launch()