# detector/app.py
# Author: root — commit 2ba28a6 ("添加lama模型" / add the LAMA model)
from inference import model2annotations, traverse_by_dict, init_model
import os
import shutil
import gradio as gr
import PIL.Image
import torch
import numpy as np
import cv2
import json
# App banner shown at the top of the Gradio page.
DESCRIPTION = "# [comic-text-detector](https://github.com/dmMaze/comic-text-detector)"

# Working directories: uploads go to INPUT_DIR, detector output to
# OUTPUT_DIR; TEMP_DIR is scratch space.
INPUT_DIR = "./input"
OUTPUT_DIR = "./output"
TEMP_DIR = "./temp"

# Make sure every working directory exists before the app starts.
for _workdir in (INPUT_DIR, OUTPUT_DIR, TEMP_DIR):
    os.makedirs(_workdir, exist_ok=True)
def inference(model):
    """Run text detection over every image in INPUT_DIR.

    Args:
        model: detector model as returned by ``init_model``.

    Returns:
        Whatever ``model2annotations`` returns — from the caller's usage,
        a JSON string with per-image detection results.
    """
    # Reuse the module-level constants instead of re-hardcoding the
    # './input' / './output' literals, so the paths live in one place.
    return model2annotations(INPUT_DIR, OUTPUT_DIR, save_json=False, model=model)
# Resolve where this script lives (kept for parity with the original code;
# not used below) and load the text-detector model on CPU.
current_directory = os.path.dirname(os.path.abspath(__file__))
model_path = './data/comictextdetector.pt.onnx'
model = init_model(model_path, device='cpu')

# The LAMA inpainting model is optional: when its TorchScript file is
# missing, inpainting is skipped and images pass through unchanged.
LAMA_MODEL_PATH = "big-lama.pt"
lama_model = None
if os.path.exists(LAMA_MODEL_PATH):
    lama_model = torch.jit.load(LAMA_MODEL_PATH, map_location='cpu')
    lama_model.eval()
def inpaint(image, blk_dict_list):
    """Erase detected text regions from a BGR image with the LAMA model.

    Args:
        image: BGR uint8 image of shape (H, W, 3), OpenCV convention.
        blk_dict_list: list of detection dicts; each must contain an
            'xyxy' entry holding (x1, y1, x2, y2) pixel coordinates.

    Returns:
        BGR uint8 image of the original size with the text boxes
        inpainted, or the input unchanged when no model is loaded or
        there is nothing to erase.
    """
    if lama_model is None:
        return image
    # No detections -> empty mask; skip the pointless model round-trip.
    if not blk_dict_list:
        return image
    h, w = image.shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    # Dilate each detection box by 8 px so stroke edges are covered too.
    for item in blk_dict_list:
        x1, y1, x2, y2 = item['xyxy']
        cv2.rectangle(
            mask,
            (max(0, int(x1) - 8), max(0, int(y1) - 8)),
            (min(w, int(x2) + 8), min(h, int(y2) + 8)),
            255, -1,
        )
    # LAMA expects spatial dims divisible by 8.  Clamp to >= 8 so images
    # narrower/shorter than 8 px don't produce a zero-sized resize
    # (the original `(h // 8) * 8` crashed on such inputs).
    mod_h, mod_w = max(8, (h // 8) * 8), max(8, (w // 8) * 8)
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_t = torch.from_numpy(cv2.resize(rgb, (mod_w, mod_h))).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    # Nearest-neighbour keeps the mask binary; the default bilinear
    # interpolation would introduce fractional mask values.
    mask_rs = cv2.resize(mask, (mod_w, mod_h), interpolation=cv2.INTER_NEAREST)
    mask_t = torch.from_numpy(mask_rs).float().unsqueeze(0).unsqueeze(0) / 255.0
    with torch.no_grad():
        inpainted_t = lama_model(img_t, mask_t)
    inpainted_img = (inpainted_t[0].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
    # Restore the original resolution and OpenCV's BGR channel order.
    inpainted_img = cv2.resize(inpainted_img, (w, h))
    return cv2.cvtColor(inpainted_img, cv2.COLOR_RGB2BGR)
def _clear_directory(dir_path: str) -> None:
    """Best-effort removal of every file, symlink and subtree in *dir_path*."""
    for filename in os.listdir(dir_path):
        file_path = os.path.join(dir_path, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            # Log and keep clearing the remaining entries.
            print(f"无法删除 {file_path}. 原因: {e}")

def process_image_and_generate_zip(image_file: PIL.Image.Image) -> tuple[str | None, PIL.Image.Image | None]:
    """Detect text in the uploaded image and return (detection JSON, inpainted image).

    Args:
        image_file: image uploaded through Gradio (PIL, RGB), or None.

    Returns:
        A tuple of the detection result string (or an error message when
        no image was supplied) and the inpainted PIL image (None when no
        image was supplied).
    """
    if image_file is None:
        return "请上传一张图片!", None
    # 1./2. Clear both working folders so stale files from a previous run
    # can't leak into this run's results (was two duplicated inline loops).
    print(f"清空 {INPUT_DIR} 文件夹...")
    _clear_directory(INPUT_DIR)
    print(f"{INPUT_DIR} 文件夹清空完成。")
    print(f"清空 {OUTPUT_DIR} 文件夹...")
    _clear_directory(OUTPUT_DIR)
    print(f"{OUTPUT_DIR} 文件夹清空完成。")
    # 3. Save the upload where the detector expects its input.
    base_filename = 'input.png'
    input_image_path = os.path.join(INPUT_DIR, base_filename)
    image_file.save(input_image_path)
    print(f"输入图片已保存到: {input_image_path}")
    # 4. Run the text detector over INPUT_DIR.
    print("调用模型...")
    result_json = inference(model)
    # 5. Erase the detected text with LAMA.
    print("执行文本消除...")
    # PIL (RGB) -> OpenCV (BGR).
    img_cv = cv2.cvtColor(np.array(image_file), cv2.COLOR_RGB2BGR)
    try:
        data = json.loads(result_json)
        if data and len(data) > 0:
            # First element is assumed to be this image's block list —
            # TODO confirm against model2annotations' return format.
            blk_dict_list = data[0]
            inpainted_cv = inpaint(img_cv, blk_dict_list)
            # Back to PIL (RGB) for Gradio.
            inpainted_pil = PIL.Image.fromarray(cv2.cvtColor(inpainted_cv, cv2.COLOR_BGR2RGB))
        else:
            inpainted_pil = image_file
    except Exception as e:
        # Inpainting problems must not crash the UI; fall back to the
        # untouched upload and still return the detection result.
        print(f"Inpaint 失败: {e}")
        inpainted_pil = image_file
    return result_json, inpainted_pil
# Gradio UI: one upload pane, one result pane, a run button and a JSON
# textbox wired to the detection + inpainting pipeline.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("上传一张图片,模型将对其进行文本检测并使用 Lama 进行文本消除。")

    with gr.Row():
        uploaded_image = gr.Image(label="上传图片", type="pil")
        cleaned_image = gr.Image(label="消去文本后的图片", type="pil")

    detect_button = gr.Button("运行模型")
    json_output = gr.Textbox(label="检测结果 (JSON)", interactive=False)

    detect_button.click(
        process_image_and_generate_zip,
        inputs=uploaded_image,
        outputs=[json_output, cleaned_image],
        api_name='detect',
    )

# Launch the Gradio app when executed as a script.
if __name__ == "__main__":
    demo.queue().launch()