csxmli's picture
Update app.py
9eba4f2 verified
raw
history blame
4.04 kB
import torch
import cv2
import numpy as np
import os
import os.path as osp
import time
import gradio as gr
# Process-wide settings, applied before the project modules below are imported.
# TORCH_CUDA_ARCH_LIST presumably targets any CUDA extensions those modules
# build; GRADIO_TEMP_DIR presumably redirects Gradio's temporary files.
# NOTE(review): torch, cv2 and numpy are already imported above; some native
# runtimes read OMP_NUM_THREADS only when first loaded -- confirm it takes effect.
os.environ.update({
    'TORCH_CUDA_ARCH_LIST': "7.5;8.6",
    "OMP_NUM_THREADS": "1",
    "GRADIO_TEMP_DIR": "./gradio_tmp",
})
from models.TextEnhancement import MARCONetPlus
from utils.utils_image import imread_uint, uint2tensor4, tensor2uint
from networks.rrdbnet2_arch import RRDBNet as BSRGAN
# Compute device, chosen once: the default CUDA GPU when available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# BSRGAN background model. It stays None until load_bg_model() builds it,
# which happens on the first request that asks for background SR.
BGModel = None
def load_bg_model():
    """Build and cache the BSRGAN background super-resolution model.

    On the first call, this creates the 2x RRDBNet network, fills it with the
    weights from ``./checkpoints/bsrgan_bg.pth``, switches it to eval mode,
    freezes its parameters and moves it to ``device``. The result is stored
    in the module-level ``BGModel``, and later calls return immediately.

    The global is assigned only after loading succeeds. A failed load
    (missing file, incompatible checkpoint) is therefore retried on the next
    call instead of leaving a randomly initialized model cached.

    Raises:
        RuntimeError: if the checkpoint holds fewer tensors than the network
            has state entries, since part of the network would otherwise keep
            its random initialization.
    """
    global BGModel
    if BGModel is not None:
        return

    model = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)
    # Load onto the CPU. The finished model is moved to ``device`` below, so
    # no second copy of the weights is placed on the GPU.
    checkpoint = torch.load('./checkpoints/bsrgan_bg.pth', map_location='cpu')
    state_dict = model.state_dict()
    # Tensors are copied by position, not by key name (presumably the
    # checkpoint uses different key names -- TODO confirm). zip() would
    # silently stop at the shorter side, so check the counts first.
    if len(checkpoint) < len(state_dict):
        raise RuntimeError(
            f"BSRGAN checkpoint has {len(checkpoint)} tensors, but the model "
            f"expects {len(state_dict)}"
        )
    for key, tensor in zip(list(state_dict), checkpoint.values()):
        state_dict[key] = tensor
    # strict=True plus load_state_dict's own shape checks reject mismatches.
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    # Publish only the fully initialized model.
    BGModel = model.to(device)
# Text restoration model (MARCONetPlus), built eagerly at import time from the
# four checkpoint files below, passed positionally in this order. The .pt name
# suggests the last one is a YOLO character detector -- confirm in models/.
TextModel = MARCONetPlus(
    *(f'./checkpoints/{name}' for name in (
        'net_w_encoder_860000.pth',
        'net_prior_860000.pth',
        'net_sr_860000.pth',
        'yolo11m_short_character.pt',
    )),
    device=device,
)
def _multiple_of_8(value):
    """Round ``value`` down to a multiple of 8, never returning less than 8."""
    return max(8, int(value) // 8 * 8)


def _bsrgan_forward(img_bgr, width, height):
    """Resize ``img_bgr`` to ``(width, height)`` and run it through ``BGModel``.

    Returns the network output converted back with ``tensor2uint``.
    Assumes ``load_bg_model()`` has already populated ``BGModel``.
    """
    resized = cv2.resize(img_bgr, (width, height), interpolation=cv2.INTER_AREA)
    with torch.no_grad():
        return tensor2uint(BGModel(uint2tensor4(resized).to(device)))


def _background_sr(img_bgr, max_size=1536):
    """Super-resolve ``img_bgr`` with BSRGAN, retrying once on a smaller input.

    The first attempt uses the image resized down to multiples of 8 pixels.
    If that raises a RuntimeError (CUDA out-of-memory errors are
    RuntimeErrors), cached GPU memory is released. The image is then retried
    scaled so that neither side exceeds ``max_size``; it is never upscaled.
    If the image already fits within ``max_size``, the retry uses the same
    size as the first attempt.

    Raises:
        RuntimeError: propagated from the retry if it fails as well.
    """
    load_bg_model()
    height, width = img_bgr.shape[:2]
    try:
        return _bsrgan_forward(img_bgr, _multiple_of_8(width), _multiple_of_8(height))
    except RuntimeError:
        # Recover *outside* this handler. While the exception is active, its
        # traceback keeps the failed forward pass's tensors alive, so freeing
        # the CUDA cache here would not release them.
        pass
    torch.cuda.empty_cache()
    scale = min(max_size / width, max_size / height, 1.0)
    return _bsrgan_forward(
        img_bgr, _multiple_of_8(width * scale), _multiple_of_8(height * scale)
    )


def _bgr_to_rgb_uint8(img):
    """Clip ``img`` to [0, 255], reverse its channel order and cast to uint8.

    Clipping keeps out-of-range float values from wrapping around on the
    uint8 cast.
    """
    return np.clip(img, 0, 255)[:, :, ::-1].astype(np.uint8)


def gradio_inference(input_img, aligned=False, bg_sr=False, scale_factor=2):
    """Run MARCONetPlus inference with optional background super-resolution.

    Args:
        input_img: RGB PIL image from the Gradio ``Image`` input, or None.
        aligned: True when the input is an already-cropped text region. The
            first enhanced text crop is returned and background SR is skipped.
        bg_sr: super-resolve the background with BSRGAN (unaligned inputs only).
        scale_factor: output upscaling factor. It is cast to int because a
            Gradio slider may deliver it as a float, and OpenCV sizes must be
            integers.

    Returns:
        The restored image as an RGB uint8 array, or None when there is no
        input or the text model produced no result.
    """
    if input_img is None:
        return None
    scale_factor = int(scale_factor)

    # The models below work on BGR arrays.
    img_L = cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)
    height_L, width_L = img_L.shape[:2]
    width_S, height_S = width_L * scale_factor, height_L * scale_factor

    # Background: the BSRGAN output when requested, otherwise the input
    # itself. Either way it is resized to the target output size.
    img_E = _background_sr(img_L) if (bg_sr and not aligned) else img_L
    img_E = cv2.resize(img_E, (width_S, height_S), interpolation=cv2.INTER_AREA)

    SQ, _, en_texts, _, _ = TextModel.handle_texts(
        img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned
    )
    if SQ is None:
        return None

    if aligned:
        # Aligned input: return the first enhanced text crop, if there is one.
        if en_texts is None or len(en_texts) == 0:
            return None
        return _bgr_to_rgb_uint8(en_texts[0])

    SQ = cv2.resize(SQ.astype(np.float32), (width_S, height_S), interpolation=cv2.INTER_AREA)
    return _bgr_to_rgb_uint8(SQ)
# ---- Gradio UI --------------------------------------------------------------
# Layout: a row with the input and result images, a row of options, and a
# button wired to gradio_inference.
with gr.Blocks() as demo:
    gr.Markdown("# MARCONetPlus Text Image Restoration")

    with gr.Row():
        input_img = gr.Image(label="Input Image", type="pil")
        output_img = gr.Image(label="Restored Output", type="numpy")

    with gr.Row():
        aligned = gr.Checkbox(value=False, label="Aligned (cropped text regions)")
        bg_sr = gr.Checkbox(value=False, label="Background SR (BSRGAN)")
        scale_factor = gr.Slider(
            minimum=1, maximum=4, step=1, value=2, label="Scale Factor"
        )

    run_btn = gr.Button("Run Inference")
    # Input components are passed positionally, in gradio_inference's
    # parameter order.
    run_btn.click(
        gradio_inference,
        inputs=[input_img, aligned, bg_sr, scale_factor],
        outputs=[output_img],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()