# Manuscript-OCR / app.py
# Hugging Face Space by konstantinkozhin — commit bef2e74 ("Update app.py").
import html
import os
import tempfile
import time

import gradio as gr
import numpy as np
from PIL import Image

from manuscript import CharLM, Pipeline
from manuscript.detectors import EAST, YOLO
from manuscript.recognizers import TRBA
# --------------------------
# Hugging Face Spaces cache
# --------------------------
# Повторяем логику из старой версии скрипта, чтобы кэш сохранялся в persistent
# storage, когда он доступен. Это полезно для ускорения загрузки моделей
# при перезапусках контейнера.
def _setup_cache_dirs():
base = "/data" if (os.path.isdir("/data") and os.access("/data", os.W_OK)) else "/tmp"
cache_root = os.path.join(base, "cache")
os.makedirs(cache_root, exist_ok=True)
os.environ.setdefault("XDG_CACHE_HOME", cache_root)
os.environ.setdefault("HF_HOME", os.path.join(cache_root, "hf"))
os.environ.setdefault("TORCH_HOME", os.path.join(cache_root, "torch"))
os.environ.setdefault("MPLCONFIGDIR", os.path.join(cache_root, "mpl"))
_setup_cache_dirs()
# --------------------------
# Models and default parameters
# --------------------------
# Model weight identifiers offered in the three UI dropdowns.
DETECTOR_MODELS = [
    "yolo26x_obb_text_g1",
    "yolo26s_obb_text_g1",
    "east_50_g1",
]
RECOGNIZER_MODELS = [
    "trba_lite_g2",
    "trba_lite_g1",
    "trba_base_g1",
]
CORRECTOR_MODELS = [
    "prereform_charlm_g1",
    "modern_charlm_g1",
]
# Per-detector slider defaults. "is_east" drives the UI: EAST models expose
# the box-expansion sliders and an editable image size; YOLO models do not.
DETECTOR_DEFAULTS = {
    "yolo26x_obb_text_g1": {
        "target_size": 1408,
        "score_thresh": 0.1,
        "expand_ratio_w": 1.4,
        "expand_ratio_h": 1.5,
        "is_east": False,
    },
    "yolo26s_obb_text_g1": {
        "target_size": 1408,
        "score_thresh": 0.1,
        "expand_ratio_w": 1.4,
        "expand_ratio_h": 1.5,
        "is_east": False,
    },
    "east_50_g1": {
        "target_size": 1280,
        "score_thresh": 0.6,
        "expand_ratio_w": 1.4,
        "expand_ratio_h": 1.5,
        "is_east": True,
    },
}
# --------------------------
# Module-level state
# --------------------------
# Pages produced by the most recent run, kept for the "save to JSON" buttons.
last_recognition_page = None
last_correction_page = None
# --------------------------
# Pipeline helpers
# --------------------------
def create_pipeline(
    detector_model: str,
    recognizer_model: str,
    corrector_model: str,
    target_size: int,
    score_thresh: float,
    expand_ratio_w: float,
    expand_ratio_h: float,
    mask_threshold: float,
    apply_threshold: float,
    max_edits: int,
):
    """Build a detection -> recognition -> correction pipeline from UI settings.

    Detector weights whose name starts with "east_" get the EAST detector
    (which takes the box-expansion ratios); all other weights go to YOLO,
    which ignores the expansion parameters.
    """
    if detector_model.startswith("east_"):
        detector = EAST(
            weights=detector_model,
            target_size=target_size,
            score_thresh=score_thresh,
            expand_ratio_w=expand_ratio_w,
            expand_ratio_h=expand_ratio_h,
        )
    else:
        detector = YOLO(
            weights=detector_model,
            target_size=target_size,
            score_thresh=score_thresh,
        )

    return Pipeline(
        detector=detector,
        recognizer=TRBA(weights=recognizer_model),
        corrector=CharLM(
            weights=corrector_model,
            mask_threshold=mask_threshold,
            apply_threshold=apply_threshold,
            max_edits=max_edits,
        ),
    )
def update_detector_controls(detector_model: str):
    """Re-sync the detector sliders when a different detector model is chosen.

    Returns gr.update objects for (target_size, score_thresh,
    expand_ratio_w, expand_ratio_h); the expansion sliders are shown and
    editable only for EAST models.
    """
    cfg = DETECTOR_DEFAULTS[detector_model]
    is_east = cfg["is_east"]
    size_label = "Размер изображения" if is_east else "Размер изображения (по модели)"
    size_update = gr.update(
        value=cfg["target_size"],
        interactive=is_east,
        label=size_label,
    )
    thresh_update = gr.update(value=cfg["score_thresh"])
    expand_w_update = gr.update(
        value=cfg["expand_ratio_w"], visible=is_east, interactive=is_east
    )
    expand_h_update = gr.update(
        value=cfg["expand_ratio_h"], visible=is_east, interactive=is_east
    )
    return size_update, thresh_update, expand_w_update, expand_h_update
def count_words_in_page(page):
    """Count non-empty text spans across every block and line of *page*.

    Returns 0 for a missing page; each span with truthy ``text`` counts as
    one word.
    """
    if page is None:
        return 0
    total = 0
    for blk in page.blocks:
        for ln in blk.lines:
            for span in ln.text_spans:
                if span.text:
                    total += 1
    return total
def highlight_differences(original: str, corrected: str) -> str:
    """Render *corrected* as HTML, marking characters changed by the corrector.

    Characters are compared position-by-position (naive alignment, no real
    diff): matching characters are emitted as-is, differing or extra
    characters from *corrected* are wrapped in a green bold <span>.
    Characters that exist only in *original* (deletions) are skipped and
    produce no output. Newlines become <br>.

    Fix over the previous version: every character is passed through
    html.escape, so "<", ">" and "&" in the OCR text can no longer break
    the markup of (or inject markup into) the gr.HTML output.
    """
    parts = []
    i, j = 0, 0
    len_o, len_c = len(original), len(corrected)
    while i < len_o or j < len_c:
        if i < len_o and j < len_c:
            if original[i] == corrected[j]:
                parts.append("<br>" if corrected[j] == "\n" else html.escape(corrected[j]))
            else:
                parts.append(
                    f'<span style="background-color:#90EE90;font-weight:bold;">'
                    f"{html.escape(corrected[j])}</span>"
                )
            i += 1
            j += 1
        elif i < len_o:
            # Character present only in the original: deletion, not rendered.
            i += 1
        else:
            # Trailing characters added by the corrector.
            parts.append(
                f'<span style="background-color:#90EE90;font-weight:bold;">'
                f"{html.escape(corrected[j])}</span>"
            )
            j += 1
    return f'<div style="white-space:pre-wrap;font-family:monospace;">{"".join(parts)}</div>'
# --------------------------
# Основная обработка изображения
# --------------------------
def process_image(
    image,
    detector_model,
    recognizer_model,
    corrector_model,
    target_size,
    score_thresh,
    expand_ratio_w,
    expand_ratio_h,
    mask_threshold,
    apply_threshold,
    max_edits,
):
    """Run the full OCR pipeline on *image* and assemble the UI outputs.

    Returns a 4-tuple: (visualization image, raw recognized text,
    highlighted corrected text as HTML, stats string). Also stores the raw
    and corrected pages in module globals for the JSON download buttons.
    On any pipeline failure the error text is surfaced in the UI instead
    of raising.
    """
    global last_recognition_page, last_correction_page

    if image is None:
        return None, "", "", ""

    try:
        pipeline = create_pipeline(
            detector_model,
            recognizer_model,
            corrector_model,
            int(target_size),
            float(score_thresh),
            float(expand_ratio_w),
            float(expand_ratio_h),
            float(mask_threshold),
            float(apply_threshold),
            int(max_edits),
        )

        started = time.time()
        _, vis_image = pipeline.predict(image, vis=True)
        elapsed = time.time() - started

        last_recognition_page = pipeline.last_recognition_page
        last_correction_page = pipeline.last_correction_page

        text_before = ""
        if last_recognition_page:
            text_before = pipeline.get_text(last_recognition_page)
        text_after = ""
        if last_correction_page:
            text_after = pipeline.get_text(last_correction_page)

        word_count = count_words_in_page(last_correction_page)
        if elapsed > 0:
            pages_per_sec = 1.0 / elapsed
            words_per_sec = word_count / elapsed
        else:
            pages_per_sec = 0.0
            words_per_sec = 0.0
        stats_text = (
            f"Время: {elapsed:.2f} сек | "
            f"{pages_per_sec:.2f} стр/сек | "
            f"{words_per_sec:.1f} слов/сек"
        )

        # Pipeline may hand back a numpy array; the gr.Image output is PIL.
        if isinstance(vis_image, np.ndarray):
            vis_image = Image.fromarray(vis_image)

        highlighted = highlight_differences(text_before, text_after)
        return vis_image, text_before, highlighted, stats_text
    except Exception as e:  # UI boundary: report the failure, don't crash the app
        error_msg = f"Ошибка: {e}"
        return None, error_msg, "", error_msg
# --------------------------
# Сохранение json-файлов
# --------------------------
def save_recognition_json():
global last_recognition_page
if last_recognition_page is None:
return None
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
f.write(last_recognition_page.to_json())
return f.name
def save_correction_json():
global last_correction_page
if last_correction_page is None:
return None
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
f.write(last_correction_page.to_json())
return f.name
# --------------------------
# Примеры изображений (как в старой версии)
# --------------------------
# Bundled sample manuscript scans offered in the gr.Examples gallery.
# Paths are relative to the Space repository root.
EXAMPLE_IMAGES = [
    "examples/img1.jpeg",
    "examples/img2.jpeg",
    "examples/img3.jpeg",
    "examples/img4.jpeg",
    "examples/img5.jpeg",
    "examples/img6.png",
]
# --------------------------
# UI (Gradio Blocks + Soft theme)
# --------------------------
def build_demo():
    """Assemble the Gradio Blocks UI and wire up all event handlers.

    Layout: left column (input image, examples, model dropdowns, parameter
    accordions, run button), right column (visualization + stats), bottom
    row (raw and corrected text with JSON download buttons). Returns the
    gr.Blocks app object.
    """
    default_detector = "yolo26x_obb_text_g1"
    default_cfg = DETECTOR_DEFAULTS[default_detector]
    with gr.Blocks(title="OCR Pipeline", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Manuscript Demo")
        with gr.Row():
            # ------------------
            # Left column
            # ------------------
            with gr.Column():
                input_image = gr.Image(label="Изображение", type="pil")
                # Examples block — carried over from the old interface
                gr.Examples(
                    examples=EXAMPLE_IMAGES,
                    inputs=input_image,
                    label="Примеры (кликни, чтобы загрузить)",
                )
                # Model selection
                with gr.Row():
                    detector_selector = gr.Dropdown(
                        choices=DETECTOR_MODELS,
                        value=default_detector,
                        label="Детектор",
                    )
                    recognizer_selector = gr.Dropdown(
                        choices=RECOGNIZER_MODELS,
                        value="trba_lite_g2",
                        label="Распознаватель",
                    )
                    corrector_selector = gr.Dropdown(
                        choices=CORRECTOR_MODELS,
                        value="prereform_charlm_g1",
                        label="Корректор",
                    )
                # Detector parameters (defaults follow the selected model)
                with gr.Accordion("Параметры детектора", open=False):
                    target_size = gr.Slider(
                        640,
                        2560,
                        value=default_cfg["target_size"],
                        step=64,
                        label="Размер изображения (по модели)",
                        interactive=False,
                    )
                    score_thresh = gr.Slider(
                        0.1,
                        0.9,
                        value=default_cfg["score_thresh"],
                        step=0.05,
                        label="Порог уверенности",
                    )
                    # EAST-only sliders; update_detector_controls toggles visibility.
                    expand_ratio_w = gr.Slider(
                        0.5,
                        3.0,
                        value=default_cfg["expand_ratio_w"],
                        step=0.1,
                        label="Расширение по ширине (EAST)",
                        visible=False,
                        interactive=False,
                    )
                    expand_ratio_h = gr.Slider(
                        0.5,
                        3.0,
                        value=default_cfg["expand_ratio_h"],
                        step=0.1,
                        label="Расширение по высоте (EAST)",
                        visible=False,
                        interactive=False,
                    )
                # Corrector parameters
                with gr.Accordion("Параметры корректора", open=False):
                    mask_threshold = gr.Slider(
                        0.0, 0.5, value=0.05, step=0.01, label="Порог маскирования"
                    )
                    apply_threshold = gr.Slider(
                        0.5, 1.0, value=0.95, step=0.01, label="Порог применения"
                    )
                    max_edits = gr.Slider(1, 10, value=1, step=1, label="Максимум правок")
                # Run button
                btn = gr.Button("Распознать", variant="primary")
            # ------------------
            # Right column
            # ------------------
            with gr.Column():
                output_image = gr.Image(label="Визуализация", type="pil")
                stats_display = gr.Textbox(label="Статистика", interactive=False)
        # ------------------
        # Bottom block with text output
        # ------------------
        with gr.Row():
            with gr.Column():
                text_before = gr.Textbox(label="Текст без корректора", lines=10)
                btn_save_recognition = gr.Button("Сохранить в JSON")
                file_recognition = gr.File(label="Результат распознавания")
            with gr.Column():
                text_after = gr.HTML(label="Текст с корректором")
                btn_save_correction = gr.Button("Сохранить в JSON")
                file_correction = gr.File(label="Результат коррекции")
        # ------------------
        # Interaction wiring
        # ------------------
        detector_selector.change(
            update_detector_controls,
            inputs=[detector_selector],
            outputs=[target_size, score_thresh, expand_ratio_w, expand_ratio_h],
        )
        btn.click(
            process_image,
            inputs=[
                input_image,
                detector_selector,
                recognizer_selector,
                corrector_selector,
                target_size,
                score_thresh,
                expand_ratio_w,
                expand_ratio_h,
                mask_threshold,
                apply_threshold,
                max_edits,
            ],
            outputs=[output_image, text_before, text_after, stats_display],
        )
        btn_save_recognition.click(
            save_recognition_json,
            inputs=[],
            outputs=[file_recognition],
        )
        btn_save_correction.click(
            save_correction_json,
            inputs=[],
            outputs=[file_correction],
        )
    return demo
# Build the app at import time so hosting platforms can pick up `demo`.
demo = build_demo()
if __name__ == "__main__":
    # share=True additionally creates a public gradio.live link when run locally.
    demo.launch(share=True)