from inference import model2annotations, traverse_by_dict, init_model
import os
import shutil
import gradio as gr
import PIL.Image
import torch
import numpy as np
import cv2
import json

DESCRIPTION = "# [comic-text-detector](https://github.com/dmMaze/comic-text-detector)"

INPUT_DIR = "./input"
OUTPUT_DIR = "./output"
TEMP_DIR = "./temp"

os.makedirs(INPUT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(TEMP_DIR, exist_ok=True)


def inference(model):
    """Run text detection on every image in INPUT_DIR, writing results to OUTPUT_DIR.

    Returns whatever model2annotations returns — downstream code calls
    json.loads on it, so it is presumably a JSON string of per-image block
    dicts (TODO confirm against the inference module).
    """
    # Reuse the module-level constants instead of re-hardcoding the paths.
    return model2annotations(INPUT_DIR, OUTPUT_DIR, save_json=False, model=model)


current_directory = os.path.dirname(os.path.abspath(__file__))
model_path = './data/comictextdetector.pt.onnx'
model = init_model(model_path, device='cpu')

# Optionally load the LaMa inpainting model; inpainting is skipped when the
# checkpoint file is absent (lama_model stays None).
LAMA_MODEL_PATH = "big-lama.pt"
if os.path.exists(LAMA_MODEL_PATH):
    lama_model = torch.jit.load(LAMA_MODEL_PATH, map_location='cpu')
    lama_model.eval()
else:
    lama_model = None


def inpaint(image, blk_dict_list):
    """Erase detected text regions from a BGR image using the LaMa model.

    Args:
        image: BGR uint8 image of shape (H, W, 3).
        blk_dict_list: iterable of dicts, each with an 'xyxy' bounding box.

    Returns:
        The inpainted BGR image, or the input unchanged when no LaMa model
        was loaded.
    """
    if lama_model is None:
        return image

    h, w = image.shape[:2]

    # Build a binary mask covering every detected text box, padded by 8 px
    # on each side so strokes at the box edge are fully removed.
    mask = np.zeros((h, w), dtype=np.uint8)
    for item in blk_dict_list:
        x1, y1, x2, y2 = item['xyxy']
        cv2.rectangle(
            mask,
            (max(0, int(x1) - 8), max(0, int(y1) - 8)),
            (min(w, int(x2) + 8), min(h, int(y2) + 8)),
            255,
            -1,
        )

    # The model input is resized to dimensions divisible by 8; clamp to >= 8
    # so images smaller than 8 px don't yield a zero-sized cv2.resize (which
    # raises) as the original (h // 8) * 8 computation would.
    mod_h, mod_w = max(8, (h // 8) * 8), max(8, (w // 8) * 8)
    img_t = (
        torch.from_numpy(
            cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), (mod_w, mod_h))
        ).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    )
    mask_t = (
        torch.from_numpy(cv2.resize(mask, (mod_w, mod_h)))
        .float().unsqueeze(0).unsqueeze(0) / 255.0
    )

    with torch.no_grad():
        inpainted_t = lama_model(img_t, mask_t)

    inpainted_img = (
        (inpainted_t[0].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
    )
    # Restore the original resolution and BGR channel order.
    inpainted_img = cv2.resize(inpainted_img, (w, h))
    return cv2.cvtColor(inpainted_img, cv2.COLOR_RGB2BGR)


def _clear_directory(directory):
    """Best-effort removal of every file, symlink, and subdirectory in *directory*."""
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f"无法删除 {file_path}. 原因: {e}")


def process_image_and_generate_zip(image_file: PIL.Image.Image) -> tuple[str | None, PIL.Image.Image | None]:
    """Detect text in the uploaded image, then erase it with LaMa.

    Args:
        image_file: the uploaded PIL image, or None when nothing was uploaded.

    Returns:
        (detection-result JSON string or an error message, inpainted PIL
        image or None).
    """
    if image_file is None:
        return "请上传一张图片!", None

    # 1. Clear the ./input folder.
    print(f"清空 {INPUT_DIR} 文件夹...")
    _clear_directory(INPUT_DIR)
    print(f"{INPUT_DIR} 文件夹清空完成。")

    # 2. Clear the ./output folder (outputs are regenerated on every run).
    print(f"清空 {OUTPUT_DIR} 文件夹...")
    _clear_directory(OUTPUT_DIR)
    print(f"{OUTPUT_DIR} 文件夹清空完成。")

    # 3. Save the uploaded image into ./input.
    base_filename = 'input.png'
    input_image_path = os.path.join(INPUT_DIR, base_filename)
    image_file.save(input_image_path)
    print(f"输入图片已保存到: {input_image_path}")

    # 4. Run text detection on the saved image.
    print("调用模型...")
    result_json = inference(model)

    # 5. Inpaint (erase the detected text).
    print("执行文本消除...")
    # Convert the PIL image to OpenCV BGR format.
    img_cv = cv2.cvtColor(np.array(image_file), cv2.COLOR_RGB2BGR)

    # Parse the detection result; on any failure fall back to returning the
    # original image rather than crashing the request.
    try:
        data = json.loads(result_json)
        if data and len(data) > 0:
            blk_dict_list = data[0]
            inpainted_cv = inpaint(img_cv, blk_dict_list)
            # Convert back to PIL RGB for Gradio.
            inpainted_pil = PIL.Image.fromarray(
                cv2.cvtColor(inpainted_cv, cv2.COLOR_BGR2RGB)
            )
        else:
            inpainted_pil = image_file
    except Exception as e:
        print(f"Inpaint 失败: {e}")
        inpainted_pil = image_file

    return result_json, inpainted_pil


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("上传一张图片,模型将对其进行文本检测并使用 Lama 进行文本消除。")
    with gr.Row():
        image_input = gr.Image(label="上传图片", type="pil")
        image_output = gr.Image(label="消去文本后的图片", type="pil")
    run_button = gr.Button("运行模型")
    message_output = gr.Textbox(label="检测结果 (JSON)", interactive=False)

    run_button.click(
        fn=process_image_and_generate_zip,
        inputs=image_input,
        outputs=[message_output, image_output],
        api_name='detect',
    )

# Launch the Gradio app.
if __name__ == "__main__":
    demo.queue().launch()