
import os
# Restrict CUDA kernel compilation to the GPU architectures actually targeted.
os.environ['TORCH_CUDA_ARCH_LIST']="7.5;8.6;9.0;9.0a"
# Single OpenMP thread to avoid CPU oversubscription in the hosted environment.
os.environ["OMP_NUM_THREADS"] = "1"
# Keep Gradio's temporary uploads in a local, writable directory.
os.environ["GRADIO_TEMP_DIR"] = "./gradio_tmp"

import spaces

import os.path as osp
import torch
import cv2
import numpy as np
import time
import gradio as gr


from models.TextEnhancement import MARCONetPlus
from utils.utils_image import imread_uint, uint2tensor4, tensor2uint
from networks.rrdbnet2_arch import RRDBNet as BSRGAN

# Initialize device: prefer GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Background restoration model (lazy loading)
BGModel = None
def load_bg_model():
    """Lazily build and cache the BSRGAN background super-resolution model.

    Populates the module-level ``BGModel`` on first call; later calls are
    no-ops. The model is loaded in eval mode with gradients disabled and
    moved to the global ``device``.
    """
    global BGModel
    if BGModel is not None:
        return
    BGModel = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)
    checkpoint = torch.load('./checkpoints/bsrgan_bg.pth', map_location=device)
    # Checkpoint parameter names differ from this architecture's, so weights
    # are copied positionally — this relies on both state dicts sharing the
    # same insertion order. TODO(review): confirm checkpoint/model alignment.
    target_state = BGModel.state_dict()
    for (_, weight), target_key in zip(checkpoint.items(), target_state.keys()):
        target_state[target_key] = weight
    BGModel.load_state_dict(target_state, strict=True)
    BGModel.eval()
    # Inference only: freeze every parameter.
    for _, param in BGModel.named_parameters():
        param.requires_grad = False
    BGModel = BGModel.to(device)

# Text restoration model, built eagerly at import time from four local
# checkpoints (encoder, prior, super-resolution, and a YOLO character
# detector). NOTE(review): fails at startup if any checkpoint is missing.
TextModel = MARCONetPlus(
    './checkpoints/net_w_encoder_860000.pth',
    './checkpoints/net_prior_860000.pth',
    './checkpoints/net_sr_860000.pth',
    './checkpoints/yolo11m_short_character.pt',
    device=device
)

@spaces.GPU(duration=120)
def gradio_inference(input_img, aligned=False, bg_sr=False, scale_factor=2):
    """Run MARCONetPlus text restoration with optional background SR.

    Args:
        input_img: PIL image from the Gradio Image component, or None.
        aligned: True when the input is an already-cropped text region.
        bg_sr: When True (and not aligned), first run BSRGAN background
            super-resolution on the whole image.
        scale_factor: Output upscaling factor. Gradio sliders can deliver
            this as a float, so it is normalized to int before use.

    Returns:
        Restored image as an RGB uint8 numpy array, or None when the input
        is missing or text restoration produced no result.
    """
    if input_img is None:
        return None

    # gr.Slider may hand back a float (e.g. 2.0); cv2.resize needs int dims.
    scale_factor = int(scale_factor)

    # Convert input image: PIL (RGB) -> OpenCV (BGR).
    img_L = cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)
    height_L, width_L = img_L.shape[:2]

    # Background super-resolution (whole-image path only).
    if not aligned and bg_sr:
        load_bg_model()
        # BSRGAN needs dimensions divisible by 8.
        img_E = cv2.resize(img_L, (width_L // 8 * 8, height_L // 8 * 8),
                           interpolation=cv2.INTER_AREA)
        img_E = uint2tensor4(img_E).to(device)
        with torch.no_grad():
            try:
                img_E = BGModel(img_E)
            except RuntimeError:
                # Most likely CUDA OOM: free cached memory and retry on a
                # downscaled copy capped at 1536 px on the long side.
                # (Was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.)
                torch.cuda.empty_cache()
                max_size = 1536
                scale = min(max_size / width_L, max_size / height_L, 1.0)
                new_width = int(width_L * scale)
                new_height = int(height_L * scale)
                img_E = cv2.resize(img_L,
                                   (new_width // 8 * 8, new_height // 8 * 8),
                                   interpolation=cv2.INTER_AREA)
                img_E = uint2tensor4(img_E).to(device)
                img_E = BGModel(img_E)
        img_E = tensor2uint(img_E)
    else:
        img_E = img_L

    # Resize the (possibly SR'ed) background to the final output size.
    width_S = width_L * scale_factor
    height_S = height_L * scale_factor
    img_E = cv2.resize(img_E, (width_S, height_S), interpolation=cv2.INTER_AREA)

    # Text restoration: enhanced text regions composited over the background.
    SQ, ori_texts, en_texts, debug_texts, pred_texts = TextModel.handle_texts(
        img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned
    )

    if SQ is None:
        return None

    if not aligned:
        SQ = cv2.resize(SQ.astype(np.float32), (width_S, height_S),
                        interpolation=cv2.INTER_AREA)
        out_img = SQ[:, :, ::-1].astype(np.uint8)  # BGR -> RGB
    else:
        out_img = en_texts[0][:, :, ::-1].astype(np.uint8)  # BGR -> RGB

    return out_img

# Gradio UI: input/output images side by side, option row, and a run button
# wired to gradio_inference. `demo` is launched from the main guard below.
with gr.Blocks() as demo:
    gr.Markdown("# MARCONetPlus Text Image Restoration")

    with gr.Row():
        image_in = gr.Image(type="pil", label="Input Image")
        image_out = gr.Image(type="numpy", label="Restored Output")

    with gr.Row():
        aligned_opt = gr.Checkbox(label="Aligned (cropped text regions)", value=False)
        bg_sr_opt = gr.Checkbox(label="Background SR (BSRGAN)", value=False)
        sf_slider = gr.Slider(1, 4, value=2, step=1, label="Scale Factor")

    run_button = gr.Button("Run Inference")

    run_button.click(
        fn=gradio_inference,
        inputs=[image_in, aligned_opt, bg_sr_opt, sf_slider],
        outputs=[image_out]
    )

if __name__ == "__main__":
    # share=True requests a public Gradio tunnel URL in addition to localhost.
    demo.launch(share=True)