# Source page header: "Spaces: Running" (hosting-page status banner, not part of the program).
import json
import os
import shutil

import cv2
import gradio as gr
import numpy as np
import PIL.Image
import torch

from inference import model2annotations, traverse_by_dict, init_model
# Markdown heading rendered at the top of the demo page.
DESCRIPTION = "# [comic-text-detector](https://github.com/dmMaze/comic-text-detector)"

# Working directories; relative paths, so they resolve against the current
# working directory of the process.
INPUT_DIR = "./input"
OUTPUT_DIR = "./output"
TEMP_DIR = "./temp"

# Create the working directories at import time. exist_ok=True makes this a
# no-op when they are already present.
for _work_dir in (INPUT_DIR, OUTPUT_DIR, TEMP_DIR):
    os.makedirs(_work_dir, exist_ok=True)
def inference(model):
    """Run text detection over the images currently stored in ``INPUT_DIR``.

    Args:
        model: Detector object forwarded unchanged to ``model2annotations``
            (in this file, the value returned by ``init_model``).

    Returns:
        Whatever ``model2annotations`` returns. The caller in this file
        parses it with ``json.loads``.
    """
    # Use the shared directory constants rather than repeating the literals,
    # so this always reads from the folder the caller writes the upload to.
    return model2annotations(INPUT_DIR, OUTPUT_DIR, save_json=False, model=model)
# Directory containing this script. Not used by the path lookups below, which
# are resolved against the current working directory.
current_directory = os.path.dirname(os.path.abspath(__file__))

# Text-detector weights, loaded once at import time on the CPU.
model_path = './data/comictextdetector.pt.onnx'
model = init_model(model_path, device='cpu')

# Optional TorchScript LaMa checkpoint used for text removal. When the file
# is absent, lama_model stays None and inpaint() returns its input unchanged.
LAMA_MODEL_PATH = "big-lama.pt"
if os.path.exists(LAMA_MODEL_PATH):
    lama_model = torch.jit.load(LAMA_MODEL_PATH, map_location='cpu')
    lama_model.eval()
else:
    lama_model = None
    # Make the degraded mode visible in the logs instead of failing silently.
    print(f"未找到 LaMa 模型文件 {LAMA_MODEL_PATH},将跳过文本消除。")
def inpaint(image, blk_dict_list):
    """Erase the detected text regions from ``image`` using the LaMa model.

    Args:
        image: Image array in OpenCV BGR channel order; it is converted with
            ``cv2.COLOR_BGR2RGB`` and treated as height x width x channels.
        blk_dict_list: Sequence of dicts, each carrying an ``'xyxy'`` entry
            with the ``x1, y1, x2, y2`` corners of one text block.

    Returns:
        A BGR image with the same height and width as ``image``. The input is
        returned unchanged when no LaMa model is loaded, when there are no
        text blocks, or when the image is smaller than 8 pixels on a side.
    """
    # Nothing to erase (or nothing to erase it with): skip the model entirely
    # so the image does not go through a needless lossy resize round-trip.
    if lama_model is None or not blk_dict_list:
        return image

    h, w = image.shape[:2]

    # The model input is resized down to the nearest multiple of 8 per side.
    mod_h, mod_w = (h // 8) * 8, (w // 8) * 8
    if mod_h == 0 or mod_w == 0:
        # cv2.resize rejects a zero-sized target; leave tiny images alone.
        return image

    # Build a filled mask over every text box, grown by `pad` pixels on each
    # side and clamped to the image bounds.
    pad = 8
    mask = np.zeros((h, w), dtype=np.uint8)
    for item in blk_dict_list:
        x1, y1, x2, y2 = (int(v) for v in item['xyxy'])
        top_left = (max(0, x1 - pad), max(0, y1 - pad))
        bottom_right = (min(w, x2 + pad), min(h, y2 + pad))
        cv2.rectangle(mask, top_left, bottom_right, 255, -1)

    # Image -> float RGB tensor of shape (1, C, H, W) scaled to [0, 1].
    rgb = cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), (mod_w, mod_h))
    img_t = torch.from_numpy(rgb).float().permute(2, 0, 1).unsqueeze(0) / 255.0

    # Nearest-neighbour keeps the mask strictly 0/255 after resizing; the
    # default bilinear filter would introduce fractional edge values.
    resized_mask = cv2.resize(mask, (mod_w, mod_h), interpolation=cv2.INTER_NEAREST)
    mask_t = torch.from_numpy(resized_mask).float().unsqueeze(0).unsqueeze(0) / 255.0

    with torch.no_grad():
        inpainted_t = lama_model(img_t, mask_t)

    # Back to uint8 HWC, restore the original size, and return in BGR order.
    inpainted_img = (inpainted_t[0].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
    inpainted_img = cv2.resize(inpainted_img, (w, h))
    return cv2.cvtColor(inpainted_img, cv2.COLOR_RGB2BGR)
def _clear_directory(directory: str) -> None:
    """Delete every entry inside ``directory``; log and continue on failure."""
    print(f"清空 {directory} 文件夹...")
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            # Best effort: one undeletable entry must not abort the request.
            print(f"无法删除 {file_path}. 原因: {e}")
    print(f"{directory} 文件夹清空完成。")


def process_image_and_generate_zip(
    image_file: PIL.Image.Image | None,
) -> tuple[str | None, PIL.Image.Image | None]:
    """Detect text in an uploaded image and return it with the text erased.

    Despite the name, no archive is produced: the function returns the
    detection result and an image.

    Args:
        image_file: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(text, image)`` pair. ``text`` is the value returned by
        ``inference`` (or a prompt asking for an upload when ``image_file`` is
        ``None``). ``image`` is the inpainted picture, the original picture
        when inpainting is skipped or fails, or ``None`` when nothing was
        uploaded.
    """
    if image_file is None:
        return "请上传一张图片!", None

    # 1-2. Start from empty input and output folders so the detector only
    # sees the image from this request.
    _clear_directory(INPUT_DIR)
    _clear_directory(OUTPUT_DIR)

    # 3. Store the upload where the detector reads its images from.
    input_image_path = os.path.join(INPUT_DIR, 'input.png')
    image_file.save(input_image_path)
    print(f"输入图片已保存到: {input_image_path}")

    # 4. Run text detection.
    print("调用模型...")
    result_json = inference(model)

    # 5. Erase the detected text. Any failure here falls back to the original
    # image so the detection result is still returned to the user.
    print("执行文本消除...")
    try:
        data = json.loads(result_json)
        if data:
            # Force three-channel RGB first so grayscale / palette / alpha
            # uploads convert cleanly to OpenCV's BGR layout.
            img_cv = cv2.cvtColor(np.array(image_file.convert("RGB")), cv2.COLOR_RGB2BGR)
            # The first element holds the text blocks for the single input image.
            inpainted_cv = inpaint(img_cv, data[0])
            inpainted_pil = PIL.Image.fromarray(cv2.cvtColor(inpainted_cv, cv2.COLOR_BGR2RGB))
        else:
            inpainted_pil = image_file
    except Exception as e:
        print(f"Inpaint 失败: {e}")
        inpainted_pil = image_file

    return result_json, inpainted_pil
# Build the demo page: an upload widget and a result image side by side, a
# run button, and a read-only box showing the detection result.
# NOTE(review): the original indentation was lost; placing only the two image
# widgets inside the Row is inferred from the statement order - confirm.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("上传一张图片,模型将对其进行文本检测并使用 Lama 进行文本消除。")

    with gr.Row():
        image_input = gr.Image(label="上传图片", type="pil")
        image_output = gr.Image(label="消去文本后的图片", type="pil")

    run_button = gr.Button("运行模型")
    message_output = gr.Textbox(label="检测结果 (JSON)", interactive=False)

    # Wire the button to the processing function; api_name registers the
    # handler under the name 'detect'.
    run_button.click(
        fn=process_image_and_generate_zip,
        inputs=image_input,
        outputs=[message_output, image_output],
        api_name='detect'
    )

# Launch the Gradio app only when executed as a script.
if __name__ == "__main__":
    demo.queue().launch()