ASureevaA commited on
Commit
9eec39f
·
1 Parent(s): f6e6de6
Files changed (1) hide show
  1. app.py +29 -8
app.py CHANGED
@@ -14,8 +14,12 @@ from transformers import (
14
  )
15
 
16
 
17
- ocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-stage1")
18
- ocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
 
 
 
 
19
  ocr_model.to("cpu")
20
 
21
  summary_pipeline = pipeline(
@@ -28,13 +32,30 @@ tts_model = VitsModel.from_pretrained("facebook/mms-tts-rus")
28
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
29
  tts_model.to("cpu")
30
 
31
- def run_ocr(image: Image.Image) -> str:
32
- if image is None:
 
 
 
 
33
  return ""
34
- pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
35
- generated_ids = ocr_model.generate(pixel_values)
36
- text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
37
- return text.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def run_summary(text: str) -> str:
40
  text = text.strip()
 
14
  )
15
 
16
 
17
+ ocr_processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(
18
+ "raxtemur/trocr-base-ru"
19
+ )
20
+ ocr_model: VisionEncoderDecoderModel = VisionEncoderDecoderModel.from_pretrained(
21
+ "raxtemur/trocr-base-ru"
22
+ )
23
  ocr_model.to("cpu")
24
 
25
  summary_pipeline = pipeline(
 
32
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
33
  tts_model.to("cpu")
34
 
35
+ def run_ocr(image_object: Image.Image) -> str:
36
+ """
37
+ Распознавание текста с изображения.
38
+ Предполагаем, что на картинке русский/кириллический или латинский печатный текст.
39
+ """
40
+ if image_object is None:
41
  return ""
42
+
43
+ rgb_image_object: Image.Image = image_object.convert("RGB")
44
+
45
+ processor_output = ocr_processor(
46
+ images=rgb_image_object,
47
+ return_tensors="pt",
48
+ )
49
+ pixel_values_tensor = processor_output.pixel_values.to("cpu")
50
+
51
+ generated_id_tensor = ocr_model.generate(pixel_values_tensor)
52
+ decoded_text_list = ocr_processor.batch_decode(
53
+ generated_id_tensor,
54
+ skip_special_tokens=True,
55
+ )
56
+
57
+ recognized_text: str = decoded_text_list[0]
58
+ return recognized_text.strip()
59
 
60
  def run_summary(text: str) -> str:
61
  text = text.strip()