ASureevaA commited on
Commit
a88eb1e
·
1 Parent(s): 170ad3a
Files changed (2) hide show
  1. app.py +164 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Optional, Any
2
+
3
+ import torch
4
+ import numpy as numpy
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from datasets import load_dataset
8
+ from transformers import (
9
+ TrOCRProcessor,
10
+ VisionEncoderDecoderModel,
11
+ pipeline,
12
+ )
13
+
14
+ ocr_processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(
15
+ "microsoft/trocr-small-printed"
16
+ )
17
+ ocr_model: VisionEncoderDecoderModel = VisionEncoderDecoderModel.from_pretrained(
18
+ "microsoft/trocr-small-printed"
19
+ )
20
+
21
+ summary_pipeline = pipeline(
22
+ task="summarization",
23
+ model="sshleifer/distilbart-cnn-12-6",
24
+ )
25
+
26
+ tts_pipeline = pipeline(
27
+ task="text-to-speech",
28
+ model="microsoft/speecht5_tts",
29
+ )
30
+
31
+ speaker_dataset = load_dataset(
32
+ path="Matthijs/cmu-arctic-xvectors",
33
+ split="validation",
34
+ )
35
+ speaker_embedding_tensor: torch.Tensor = torch.tensor(
36
+ speaker_dataset[7306]["xvector"]
37
+ ).unsqueeze(0)
38
+
39
+ def run_ocr(image_object: Image.Image) -> str:
40
+ """
41
+ Распознавание текста с изображения с помощью трансформера OCR.
42
+ Предполагаем, что на картинке простой напечатанный текст.
43
+ """
44
+ if image_object is None:
45
+ return ""
46
+
47
+ processor_output = ocr_processor(
48
+ images=image_object,
49
+ return_tensors="pt",
50
+ )
51
+ pixel_values_tensor = processor_output.pixel_values
52
+
53
+ generated_id_tensor = ocr_model.generate(pixel_values_tensor)
54
+ decoded_text_list = ocr_processor.batch_decode(
55
+ generated_id_tensor,
56
+ skip_special_tokens=True,
57
+ )
58
+
59
+ recognized_text: str = decoded_text_list[0]
60
+ return recognized_text.strip()
61
+
62
+
63
+ def run_summarization(
64
+ input_text: str,
65
+ max_summary_tokens: int = 128,
66
+ ) -> str:
67
+ """
68
+ Суммаризация текста.
69
+ Здесь без разбиения на чанки, поэтому для очень длинных текстов могут быть проблемы.
70
+ """
71
+ cleaned_text: str = input_text.strip()
72
+ if not cleaned_text:
73
+ return ""
74
+
75
+ summary_result_list = summary_pipeline(
76
+ cleaned_text,
77
+ max_length=max_summary_tokens,
78
+ min_length=max(16, max_summary_tokens // 3),
79
+ do_sample=False,
80
+ )
81
+
82
+ summary_text: str = summary_result_list[0]["summary_text"].strip()
83
+ return summary_text
84
+
85
+
86
+ def run_tts(summary_text: str) -> Optional[Tuple[int, Any]]:
87
+ """
88
+ Озвучка текста конспекта.
89
+ Возвращает кортеж (частота_дискретизации, аудиоданные) или None, если текста нет.
90
+ """
91
+ cleaned_text: str = summary_text.strip()
92
+ if not cleaned_text:
93
+ return None
94
+
95
+ tts_output = tts_pipeline(
96
+ cleaned_text,
97
+ forward_params={"speaker_embeddings": speaker_embedding_tensor},
98
+ )
99
+
100
+ sampling_rate_int: int = int(tts_output["sampling_rate"])
101
+ audio_array = tts_output["audio"]
102
+
103
+ if not isinstance(audio_array, numpy.ndarray):
104
+ audio_array = numpy.array(audio_array)
105
+
106
+ return sampling_rate_int, audio_array
107
+
108
+
109
+ def full_flow(
110
+ image_object: Image.Image,
111
+ max_summary_tokens: int = 128,
112
+ ) -> Tuple[str, str, Optional[Tuple[int, Any]]]:
113
+
114
+ recognized_text: str = run_ocr(image_object=image_object)
115
+
116
+ summary_text: str = run_summarization(
117
+ input_text=recognized_text,
118
+ max_summary_tokens=max_summary_tokens,
119
+ )
120
+
121
+ audio_tuple = run_tts(summary_text=summary_text)
122
+
123
+ return recognized_text, summary_text, audio_tuple
124
+
125
+ gradio_interface = gr.Interface(
126
+ fn=full_flow,
127
+ inputs=[
128
+ gr.Image(
129
+ type="pil",
130
+ label="Изображение с напечатанным текстом (английский)",
131
+ ),
132
+ gr.Slider(
133
+ minimum=32,
134
+ maximum=256,
135
+ value=128,
136
+ step=16,
137
+ label="Максимальная длина конспекта (токены, примерно)",
138
+ ),
139
+ ],
140
+ outputs=[
141
+ gr.Textbox(
142
+ label="Распознанный текст (OCR)",
143
+ lines=6,
144
+ ),
145
+ gr.Textbox(
146
+ label="Конспект (суммаризация)",
147
+ lines=6,
148
+ ),
149
+ gr.Audio(
150
+ label="Озвучка конспекта (TTS)",
151
+ type="numpy",
152
+ ),
153
+ ],
154
+ title="Картинка → Конспект → Озвучка (Transformers)",
155
+ description=(
156
+ "1) Трансформер OCR распознаёт текст с изображения. "
157
+ "2) Трансформер суммаризации сокращает текст до конспекта. "
158
+ "3) Трансформер TTS озвучивает конспект."
159
+ ),
160
+ )
161
+
162
+
163
+ if __name__ == "__main__":
164
+ gradio_interface.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ datasets
4
+ sentencepiece
5
+ soundfile
6
+ gradio
7
+ Pillow
8
+ numpy