Spaces:
Running
Running
| import os | |
| import time | |
| import tempfile | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from manuscript import CharLM, Pipeline | |
| from manuscript.detectors import EAST, YOLO | |
| from manuscript.recognizers import TRBA | |
| # -------------------------- | |
| # Hugging Face Spaces cache | |
| # -------------------------- | |
| # Mirror the old script's logic so that caches land in persistent storage | |
| # when it is available; this speeds up model loading across container | |
| # restarts. | |
| def _setup_cache_dirs(): | |
| base = "/data" if (os.path.isdir("/data") and os.access("/data", os.W_OK)) else "/tmp" | |
| cache_root = os.path.join(base, "cache") | |
| os.makedirs(cache_root, exist_ok=True) | |
| os.environ.setdefault("XDG_CACHE_HOME", cache_root) | |
| os.environ.setdefault("HF_HOME", os.path.join(cache_root, "hf")) | |
| os.environ.setdefault("TORCH_HOME", os.path.join(cache_root, "torch")) | |
| os.environ.setdefault("MPLCONFIGDIR", os.path.join(cache_root, "mpl")) | |
| _setup_cache_dirs() | |
# --------------------------
# Models and default parameters
# --------------------------
# Weight identifiers resolved by the `manuscript` package.  The UI dropdowns
# below offer these choices; the first entry of each list matches the default
# selected in `build_demo`.
DETECTOR_MODELS = [
    "yolo26x_obb_text_g1",
    "yolo26s_obb_text_g1",
    "east_50_g1",
]
RECOGNIZER_MODELS = [
    "trba_lite_g2",
    "trba_lite_g1",
    "trba_base_g1",
]
CORRECTOR_MODELS = [
    "prereform_charlm_g1",
    "modern_charlm_g1",
]
# Per-detector slider defaults.  `is_east` marks weights that understand the
# EAST-specific box-expansion ratios; for non-EAST (YOLO) models those sliders
# are hidden and the input size is fixed (see `update_detector_controls`).
DETECTOR_DEFAULTS = {
    "yolo26x_obb_text_g1": {
        "target_size": 1408,  # inference image size, px
        "score_thresh": 0.1,  # detection confidence threshold
        "expand_ratio_w": 1.4,
        "expand_ratio_h": 1.5,
        "is_east": False,
    },
    "yolo26s_obb_text_g1": {
        "target_size": 1408,
        "score_thresh": 0.1,
        "expand_ratio_w": 1.4,
        "expand_ratio_h": 1.5,
        "is_east": False,
    },
    "east_50_g1": {
        "target_size": 1280,
        "score_thresh": 0.6,
        "expand_ratio_w": 1.4,
        "expand_ratio_h": 1.5,
        "is_east": True,
    },
}
# --------------------------
# Module-level state
# --------------------------
# Pages produced by the most recent `process_image` call; read back by the
# JSON-export button handlers.  None until the first successful run.
last_recognition_page = None
last_correction_page = None
# --------------------------
# Pipeline helpers
# --------------------------
def create_pipeline(
    detector_model: str,
    recognizer_model: str,
    corrector_model: str,
    target_size: int,
    score_thresh: float,
    expand_ratio_w: float,
    expand_ratio_h: float,
    mask_threshold: float,
    apply_threshold: float,
    max_edits: int,
):
    """Assemble a detector → recognizer → corrector pipeline from UI settings."""
    # Only EAST weights understand the box-expansion ratio knobs.
    if detector_model.startswith("east_"):
        detector = EAST(
            weights=detector_model,
            target_size=target_size,
            score_thresh=score_thresh,
            expand_ratio_w=expand_ratio_w,
            expand_ratio_h=expand_ratio_h,
        )
    else:
        detector = YOLO(
            weights=detector_model,
            target_size=target_size,
            score_thresh=score_thresh,
        )
    return Pipeline(
        detector=detector,
        recognizer=TRBA(weights=recognizer_model),
        corrector=CharLM(
            weights=corrector_model,
            mask_threshold=mask_threshold,
            apply_threshold=apply_threshold,
            max_edits=max_edits,
        ),
    )
def update_detector_controls(detector_model: str):
    """Re-sync the detector sliders after the model dropdown changes."""
    cfg = DETECTOR_DEFAULTS[detector_model]
    east = cfg["is_east"]
    size_label = "Размер изображения" if east else "Размер изображения (по модели)"
    size_update = gr.update(
        value=cfg["target_size"],
        interactive=east,  # input size is only adjustable for EAST weights
        label=size_label,
    )
    score_update = gr.update(value=cfg["score_thresh"])
    # The expansion-ratio sliders only make sense for EAST; hide them otherwise.
    expand_w_update = gr.update(value=cfg["expand_ratio_w"], visible=east, interactive=east)
    expand_h_update = gr.update(value=cfg["expand_ratio_h"], visible=east, interactive=east)
    return size_update, score_update, expand_w_update, expand_h_update
def count_words_in_page(page):
    """Count non-empty text spans (≈ recognized words) across *page*; 0 for None."""
    if page is None:
        return 0
    total = 0
    for page_block in page.blocks:
        for page_line in page_block.lines:
            total += sum(1 for span in page_line.text_spans if span.text)
    return total
def highlight_differences(original: str, corrected: str) -> str:
    """Render *corrected* as HTML, marking characters that differ from *original*.

    The comparison is positional (not a true diff): characters at the same
    index are compared one-by-one, extra trailing characters in *corrected*
    are treated as insertions, and characters only present in *original* are
    silently dropped.  Changed/inserted characters get a green bold span.

    Fix over the previous version: text is now passed through ``html.escape``
    so that ``<``, ``>`` and ``&`` coming from OCR output cannot break (or
    inject into) the generated markup.
    """
    from html import escape  # local import; keeps the name free at module level

    mark = '<span style="background-color:#90EE90;font-weight:bold;">{}</span>'
    parts = []
    i, j = 0, 0
    while i < len(original) or j < len(corrected):
        if i < len(original) and j < len(corrected):
            if original[i] == corrected[j]:
                # Unchanged character; newlines are emitted as <br> for HTML.
                parts.append("<br>" if original[i] == "\n" else escape(corrected[j]))
            else:
                parts.append(mark.format(escape(corrected[j])))
            i += 1
            j += 1
        elif i < len(original):
            i += 1  # character deleted by the corrector: not rendered
        else:
            parts.append(mark.format(escape(corrected[j])))
            j += 1
    body = "".join(parts)
    return f'<div style="white-space:pre-wrap;font-family:monospace;">{body}</div>'
# --------------------------
# Main image handler
# --------------------------
def process_image(
    image,
    detector_model,
    recognizer_model,
    corrector_model,
    target_size,
    score_thresh,
    expand_ratio_w,
    expand_ratio_h,
    mask_threshold,
    apply_threshold,
    max_edits,
):
    """Run the full pipeline on *image*.

    Returns a 4-tuple for the Gradio outputs: (visualisation image, raw
    recognized text, corrected text as highlighted HTML, stats line).  Also
    stores the recognition/correction pages in module state for the JSON
    export buttons.  Any failure is reported via the text outputs.
    """
    global last_recognition_page, last_correction_page
    if image is None:
        return None, "", "", ""
    try:
        pipeline = create_pipeline(
            detector_model,
            recognizer_model,
            corrector_model,
            int(target_size),
            float(score_thresh),
            float(expand_ratio_w),
            float(expand_ratio_h),
            float(mask_threshold),
            float(apply_threshold),
            int(max_edits),
        )

        started = time.time()
        _, vis_image = pipeline.predict(image, vis=True)
        elapsed = time.time() - started

        last_recognition_page = pipeline.last_recognition_page
        last_correction_page = pipeline.last_correction_page

        text_before = ""
        if last_recognition_page:
            text_before = pipeline.get_text(last_recognition_page)
        text_after = ""
        if last_correction_page:
            text_after = pipeline.get_text(last_correction_page)

        words = count_words_in_page(last_correction_page)
        if elapsed > 0:
            pages_per_sec = 1.0 / elapsed
            words_per_sec = words / elapsed
        else:
            pages_per_sec = 0.0
            words_per_sec = 0.0
        stats_text = (
            f"Время: {elapsed:.2f} сек | "
            f"{pages_per_sec:.2f} стр/сек | "
            f"{words_per_sec:.1f} слов/сек"
        )

        # Pipeline visualisations may come back as arrays; Gradio wants PIL.
        if isinstance(vis_image, np.ndarray):
            vis_image = Image.fromarray(vis_image)

        highlighted = highlight_differences(text_before, text_after)
        return vis_image, text_before, highlighted, stats_text
    except Exception as e:
        error_msg = f"Ошибка: {e}"
        return None, error_msg, "", error_msg
# --------------------------
# JSON export helpers
# --------------------------
def save_recognition_json():
    """Dump the last raw-recognition page to a temp JSON file.

    Returns the file path for the gr.File output, or None when nothing has
    been recognized yet.  The file is intentionally not deleted so Gradio
    can serve it after this handler returns.
    """
    if last_recognition_page is None:
        return None
    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    )
    with tmp as f:
        f.write(last_recognition_page.to_json())
    return tmp.name
def save_correction_json():
    """Dump the last corrected page to a temp JSON file.

    Returns the file path for the gr.File output, or None when no correction
    result exists yet.  The file is intentionally not deleted so Gradio can
    serve it after this handler returns.
    """
    if last_correction_page is None:
        return None
    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    )
    with tmp as f:
        f.write(last_correction_page.to_json())
    return tmp.name
# --------------------------
# Example images (carried over from the old UI)
# --------------------------
# Paths are relative to the app's working directory; they feed the
# gr.Examples widget in `build_demo`.
EXAMPLE_IMAGES = [
    "examples/img1.jpeg",
    "examples/img2.jpeg",
    "examples/img3.jpeg",
    "examples/img4.jpeg",
    "examples/img5.jpeg",
    "examples/img6.png",
]
# --------------------------
# UI (Gradio Blocks + Soft theme)
# --------------------------
def build_demo():
    """Build the Gradio Blocks UI and wire every event handler.

    Layout: left column = input image, examples, model selection and
    parameter accordions; right column = visualisation + stats; bottom row =
    raw vs corrected text with JSON export buttons.  Returns the Blocks app.
    """
    default_detector = "yolo26x_obb_text_g1"
    default_cfg = DETECTOR_DEFAULTS[default_detector]
    with gr.Blocks(title="OCR Pipeline", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Manuscript Demo")
        with gr.Row():
            # ------------------
            # Left column
            # ------------------
            with gr.Column():
                input_image = gr.Image(label="Изображение", type="pil")
                # Examples block — restored from the old interface
                gr.Examples(
                    examples=EXAMPLE_IMAGES,
                    inputs=input_image,
                    label="Примеры (кликни, чтобы загрузить)",
                )
                # Model selection
                with gr.Row():
                    detector_selector = gr.Dropdown(
                        choices=DETECTOR_MODELS,
                        value=default_detector,
                        label="Детектор",
                    )
                    recognizer_selector = gr.Dropdown(
                        choices=RECOGNIZER_MODELS,
                        value="trba_lite_g2",
                        label="Распознаватель",
                    )
                    corrector_selector = gr.Dropdown(
                        choices=CORRECTOR_MODELS,
                        value="prereform_charlm_g1",
                        label="Корректор",
                    )
                # Detector parameters (initial state matches the default
                # YOLO detector: fixed size, EAST sliders hidden)
                with gr.Accordion("Параметры детектора", open=False):
                    target_size = gr.Slider(
                        640,
                        2560,
                        value=default_cfg["target_size"],
                        step=64,
                        label="Размер изображения (по модели)",
                        interactive=False,
                    )
                    score_thresh = gr.Slider(
                        0.1,
                        0.9,
                        value=default_cfg["score_thresh"],
                        step=0.05,
                        label="Порог уверенности",
                    )
                    expand_ratio_w = gr.Slider(
                        0.5,
                        3.0,
                        value=default_cfg["expand_ratio_w"],
                        step=0.1,
                        label="Расширение по ширине (EAST)",
                        visible=False,
                        interactive=False,
                    )
                    expand_ratio_h = gr.Slider(
                        0.5,
                        3.0,
                        value=default_cfg["expand_ratio_h"],
                        step=0.1,
                        label="Расширение по высоте (EAST)",
                        visible=False,
                        interactive=False,
                    )
                # Corrector parameters
                with gr.Accordion("Параметры корректора", open=False):
                    mask_threshold = gr.Slider(
                        0.0, 0.5, value=0.05, step=0.01, label="Порог маскирования"
                    )
                    apply_threshold = gr.Slider(
                        0.5, 1.0, value=0.95, step=0.01, label="Порог применения"
                    )
                    max_edits = gr.Slider(1, 10, value=1, step=1, label="Максимум правок")
                # Run button
                btn = gr.Button("Распознать", variant="primary")
            # ------------------
            # Right column
            # ------------------
            with gr.Column():
                output_image = gr.Image(label="Визуализация", type="pil")
                stats_display = gr.Textbox(label="Статистика", interactive=False)
        # ------------------
        # Bottom text block
        # ------------------
        with gr.Row():
            with gr.Column():
                text_before = gr.Textbox(label="Текст без корректора", lines=10)
                btn_save_recognition = gr.Button("Сохранить в JSON")
                file_recognition = gr.File(label="Результат распознавания")
            with gr.Column():
                text_after = gr.HTML(label="Текст с корректором")
                btn_save_correction = gr.Button("Сохранить в JSON")
                file_correction = gr.File(label="Результат коррекции")
        # ------------------
        # Event wiring
        # ------------------
        # Changing the detector re-syncs its parameter sliders.
        detector_selector.change(
            update_detector_controls,
            inputs=[detector_selector],
            outputs=[target_size, score_thresh, expand_ratio_w, expand_ratio_h],
        )
        btn.click(
            process_image,
            inputs=[
                input_image,
                detector_selector,
                recognizer_selector,
                corrector_selector,
                target_size,
                score_thresh,
                expand_ratio_w,
                expand_ratio_h,
                mask_threshold,
                apply_threshold,
                max_edits,
            ],
            outputs=[output_image, text_before, text_after, stats_display],
        )
        btn_save_recognition.click(
            save_recognition_json,
            inputs=[],
            outputs=[file_recognition],
        )
        btn_save_correction.click(
            save_correction_json,
            inputs=[],
            outputs=[file_correction],
        )
    return demo
# Build at import time so Spaces (which imports the module) serves the app.
demo = build_demo()

if __name__ == "__main__":
    # NOTE(review): share=True opens a public gradio.live tunnel when run
    # locally; confirm this is intended outside of Spaces.
    demo.launch(share=True)