# ExtractIt / app.py
# shawntan123's picture
# Update app.py
# cdf1c08 verified
import gradio as gr
import numpy as np
import cv2
import onnxruntime as ort
from PIL import Image, ImageEnhance, ImageFilter, ImageDraw, ImageFont
from rembg import remove, new_session
import easyocr
import torch
from spandrel import ModelLoader
import os
import onnxruntime as ort
from transformers import pipeline, AutoModelForImageSegmentation, AutoConfig
import json
# --- Global state: model cache ---
# Each slot starts out as None; a slot holds a loaded model/session once a
# handler has stored one there.
models = dict.fromkeys(
    (
        "birefnet",
        "bria_session",  # this entry must be present
        "ocr",
        "anime_session",
        "upscaler",
    )
)
def identify_objects(input_img):
    """Detect objects in an image and return an annotated copy plus a JSON report.

    The ``hustvl/yolos-tiny`` object-detection pipeline is created on first
    use and cached in ``models["object_detector"]``. Every detection whose
    score is above 0.5 is outlined in red on an RGB copy of the input and
    listed in the report.

    Args:
        input_img: PIL image to analyse, or ``None``.

    Returns:
        A ``(image, report)`` tuple where ``report`` is always a JSON string:

        * no input: ``(None, {"error": ...})``
        * success: the annotated copy and
          ``{"status": "success", "count": ..., "detections": [...]}``
        * failure: the original ``input_img`` and
          ``{"status": "error", "message": ...}``
    """
    if input_img is None:
        return None, json.dumps({"error": "请上传图片"}, ensure_ascii=False)

    score_threshold = 0.5
    try:
        # ``.get(...) is None`` (rather than ``not in``) keeps lazy loading
        # working even if the slot is pre-registered as None like the other
        # cache entries.
        detector = models.get("object_detector")
        if detector is None:
            from transformers import pipeline

            detector = pipeline(
                "object-detection",
                model="hustvl/yolos-tiny",
                device=-1,
            )
            models["object_detector"] = detector

        # convert() yields a new image, so drawing never touches the input.
        img_pil = input_img.convert("RGB")
        results = detector(img_pil)
        draw = ImageDraw.Draw(img_pil)

        # Structured records for every detection that passes the threshold.
        detections = []
        for res in results:
            score = res["score"]
            if not score > score_threshold:
                continue
            box = res["box"]
            draw.rectangle(
                [box["xmin"], box["ymin"], box["xmax"], box["ymax"]],
                outline="red",
                width=4,
            )
            detections.append({
                "id": len(detections) + 1,
                "label": res["label"],
                # float() guards against non-native numeric types that
                # json.dumps cannot serialise; keep 4 decimal places.
                "confidence": round(float(score), 4),
                "box": {
                    "xmin": int(box["xmin"]),
                    "ymin": int(box["ymin"]),
                    "xmax": int(box["xmax"]),
                    "ymax": int(box["ymax"]),
                },
            })

        output_json = {
            "status": "success",
            "count": len(detections),
            "detections": detections,
        }
        # ensure_ascii=False keeps non-ASCII labels readable in the output.
        return img_pil, json.dumps(output_json, ensure_ascii=False, indent=4)
    except Exception as e:
        # Log the full traceback (as the other handlers do) before reporting
        # a short message to the UI.
        import traceback
        print(traceback.format_exc())
        error_json = {
            "status": "error",
            "message": str(e),
        }
        return input_img, json.dumps(error_json, ensure_ascii=False)
def process_rembg(input_img):
    """Remove the background of an image using the BiRefNet-lite model.

    The model is loaded on first use and cached in ``models["birefnet"]``.

    Args:
        input_img: PIL image to process, or ``None``.

    Returns:
        An RGB copy of the input with the predicted mask installed as its
        alpha channel, or ``None`` when ``input_img`` is ``None``.

    Raises:
        gr.Error: If model loading or inference fails. The handler feeds a
            ``gr.Image`` output, which cannot display an error string, so the
            message is surfaced through Gradio's error mechanism instead of
            being returned.
    """
    if input_img is None:
        return None
    try:
        model = models["birefnet"]
        if model is None:
            from transformers import AutoModelForImageSegmentation

            # NOTE(review): trust_remote_code=True runs Python code shipped
            # with the model repository; consider pinning a revision.
            model = AutoModelForImageSegmentation.from_pretrained(
                "ZhengPeng7/BiRefNet_lite", trust_remote_code=True
            )
            # Patch: avoids an error raised by some transformers versions.
            if not hasattr(model, "all_tied_weights_keys"):
                model.all_tied_weights_keys = []
            model.to("cpu").eval()
            # Cache only after setup succeeded, so a failure above never
            # leaves a half-initialised model in the cache.
            models["birefnet"] = model

        img_pil = input_img.convert("RGB")
        w, h = img_pil.size

        # Preprocess: resize to 1024x1024 and scale pixel values to [0, 1].
        img_resized = img_pil.resize((1024, 1024), Image.BILINEAR)
        img_np = np.array(img_resized).astype(np.float32) / 255.0
        # Standard ImageNet normalisation parameters.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        img_np = (img_np - mean) / std
        img_np = img_np.transpose((2, 0, 1))  # HWC -> CHW
        # Force float32 so the tensor dtype matches the model weights.
        img_tensor = torch.from_numpy(img_np).unsqueeze(0).float()

        with torch.no_grad():
            # BiRefNet returns multi-scale outputs; use the last one.
            preds = model(img_tensor)[-1].sigmoid().cpu()
            # Resize the mask back to the original image size.
            mask = torch.nn.functional.interpolate(
                preds, size=(h, w), mode="bilinear", align_corners=False
            )[0, 0]
        mask_np = (mask.numpy() * 255).astype(np.uint8)

        # Slightly blur the mask to soften the cut-out edges.
        mask_pil = Image.fromarray(mask_np)
        mask_pil = mask_pil.filter(ImageFilter.GaussianBlur(radius=0.8))

        # Compose: the mask becomes the alpha channel of the result.
        result = img_pil.copy()
        result.putalpha(mask_pil)
        return result
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        raise gr.Error(f"BiRefNet 最终修复版失败: {str(e)}") from e
# 2. Text recognition (OCR)
def process_ocr(input_img):
    """Run OCR on an image and return the text plus per-block details.

    The EasyOCR reader (Simplified Chinese + English, CPU) is created on
    first use and cached in ``models["ocr_reader"]``. Only results with a
    confidence above 0.3 are kept.

    Args:
        input_img: PIL image to read, or ``None``.

    Returns:
        A ``(text, details_json)`` tuple of strings. ``text`` holds one
        recognised block per line; ``details_json`` is a JSON list of
        ``{"text", "confidence", "bbox"}`` objects. When there is no input,
        nothing is recognised, or an error occurs, ``text`` is a status
        message and ``details_json`` is ``"{}"``.
    """
    if input_img is None:
        return "未上传图片", "{}"
    try:
        # 1. Lazily create the reader. ``.get(...) is None`` keeps this
        #    working even if the slot is pre-registered as None.
        reader = models.get("ocr_reader")
        if reader is None:
            print("正在初始化 OCR 阅读器...")
            reader = easyocr.Reader(['ch_sim', 'en'], gpu=False)
            models["ocr_reader"] = reader

        # 2. Normalise the mode first (as the other handlers do) so palette,
        #    alpha or CMYK inputs become a plain 3-channel array.
        img_np = np.array(input_img.convert("RGB"))
        results = reader.readtext(img_np)

        # 3. Keep confident results. Each bbox is four [x, y] corner points;
        #    coordinates are cast to plain ints so they serialise to JSON.
        ocr_details = [
            {
                "text": text,
                "confidence": round(float(prob), 4),
                "bbox": [[int(x), int(y)] for x, y in bbox],
            }
            for bbox, text, prob in results
            if prob > 0.3
        ]
        if not ocr_details:
            return "未识别到文字", "{}"

        full_text = "".join(f"{item['text']}\n" for item in ocr_details)
        json_output = json.dumps(ocr_details, ensure_ascii=False, indent=2)
        return full_text, json_output
    except Exception as e:
        # Log the full traceback (as the other handlers do); the UI only
        # receives the short message.
        import traceback
        print(traceback.format_exc())
        return f"OCR 识别出错: {str(e)}", "{}"
def process_upscale(input_img):
    """Upscale an image 2x with the RealESRGAN x2plus model on CPU.

    The model is loaded from ``RealESRGAN_x2plus.pth`` on first use and cached
    in ``models["upscaler"]``. Inputs whose longest side exceeds 1024 px are
    shrunk to fit 1024 px before upscaling, and the result is lightly
    sharpened.

    Args:
        input_img: PIL image to enhance, or ``None``.

    Returns:
        The upscaled PIL image, or ``None`` when ``input_img`` is ``None``.

    Raises:
        gr.Error: If the model file is missing or inference fails. The
            handler feeds a ``gr.Image`` output, which cannot display an
            error string, so the message is raised instead of returned.
    """
    if input_img is None:
        return None

    # Path of the 2x model checkpoint.
    model_path = "RealESRGAN_x2plus.pth"
    if not os.path.exists(model_path):
        raise gr.Error("请确认 RealESRGAN_x2plus.pth 已上传")

    try:
        upscaler = models["upscaler"]
        if upscaler is None:
            ckpt = ModelLoader().load_from_file(model_path)
            upscaler = ckpt.model.to("cpu").eval()
            models["upscaler"] = upscaler
            # Cap the thread count so CPU saturation does not freeze the page.
            torch.set_num_threads(4)

        # --- Performance pre-processing ---
        img_pil = input_img.convert("RGB")
        # Large inputs (longest side over 1024 px) are shrunk before the 2x
        # pass to keep inference time bounded.
        max_input = 1024
        if max(img_pil.size) > max_input:
            img_pil.thumbnail((max_input, max_input), Image.LANCZOS)

        img = np.array(img_pil).astype(np.float32) / 255.0
        # HWC -> CHW, add a batch dimension, force float32.
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()

        # --- Inference ---
        with torch.no_grad():
            output = upscaler(img_tensor)

        # CHW -> HWC, clamp to [0, 1] and convert to 8-bit pixels.
        output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy()
        output = (output * 255).astype(np.uint8)
        result_pil = Image.fromarray(output)

        # --- Visual compensation ---
        # Only a 2x enlargement is performed, so add mild sharpening.
        return ImageEnhance.Sharpness(result_pil).enhance(1.2)
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        raise gr.Error(f"2x 超分失败: {str(e)}") from e
def fast_smart_sharpen(img):
    """Enlarge an image 2x with Lanczos resampling, then sharpen it.

    The enlarged image is passed through an unsharp mask to strengthen edges
    and given a small contrast boost.

    Args:
        img: PIL image to process.

    Returns:
        The enlarged, sharpened PIL image.
    """
    width, height = img.size
    doubled = img.resize((width * 2, height * 2), Image.LANCZOS)
    # Unsharp mask: emphasises object edges so soft photos look crisper.
    edge_filter = ImageFilter.UnsharpMask(radius=2, percent=120, threshold=2)
    sharpened = doubled.filter(edge_filter)
    # Slight contrast lift so the picture looks less flat.
    return ImageEnhance.Contrast(sharpened).enhance(1.05)
# 4. Anime stylization (uses the uploaded AnimeGANv3_large_Ghibli_c1_e299.onnx)
def process_anime(input_img):
    """Restyle a photo with the AnimeGANv3 Ghibli ONNX model on CPU.

    The ONNX Runtime session is created on first use and cached in
    ``models["anime_session"]``. The image is resized so that its longest side
    is at most 1024 px and both sides are multiples of 32, run through the
    model, enhanced, and resized back to the original dimensions.

    Args:
        input_img: PIL image to stylise, or ``None``.

    Returns:
        The stylised PIL image at the input's original size, or ``None`` when
        ``input_img`` is ``None``.

    Raises:
        gr.Error: If the model file is missing or inference fails. The
            handler feeds a ``gr.Image`` output, which cannot display an
            error string, so the message is raised instead of returned.
    """
    if input_img is None:
        return None

    # 1. Require the model file up front.
    model_path = "AnimeGANv3_large_Ghibli_c1_e299.onnx"
    if not os.path.exists(model_path):
        raise gr.Error(f"错误:未找到模型文件 {model_path},请确认已上传到根目录")

    try:
        # 2. Create the session on first use.
        session = models["anime_session"]
        if session is None:
            print(f"正在加载模型: {model_path}...")
            session = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )
            models["anime_session"] = session

        # 3. Look up the graph's input/output node names.
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        # 4. Pre-process (work at up to 1024 px for better quality).
        img_pil = input_img.convert("RGB")
        w, h = img_pil.size
        process_size = 1024
        scale = process_size / max(w, h) if max(w, h) > process_size else 1.0
        # Round each side down to a multiple of 32, but never below 32:
        # a side shorter than 32 px would otherwise round to 0 and make the
        # resize (and inference) fail.
        target_w = max(32, int((w * scale) // 32) * 32)
        target_h = max(32, int((h * scale) // 32) * 32)
        img_resized = img_pil.resize((target_w, target_h), Image.LANCZOS)
        # Scale pixels from [0, 255] to [-1, 1] and add a batch dimension.
        img_np = np.array(img_resized).astype(np.float32)
        img_np = (img_np / 127.5) - 1.0
        img_np = np.expand_dims(img_np, axis=0)

        # 5. Run inference.
        output = session.run([output_name], {input_name: img_np})[0]

        # 6. Post-process: map [-1, 1] back to 8-bit pixels, then enhance.
        output = np.squeeze(output)
        output = (output + 1.0) * 127.5
        output = np.clip(output, 0, 255).astype(np.uint8)
        result_pil = Image.fromarray(output)
        result_pil = ImageEnhance.Sharpness(result_pil).enhance(1.5)
        result_pil = ImageEnhance.Contrast(result_pil).enhance(1.1)
        return result_pil.resize((w, h), Image.LANCZOS)
    except Exception as e:
        import traceback
        # Full traceback goes to the logs; the UI gets the short message.
        print(traceback.format_exc())
        raise gr.Error(f"推理失败: {str(e)}") from e
# --- Gradio UI layout ---
# One tab per feature; each tab wires its button to the matching handler.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# 🚀 ExtractIt AI 多功能平台")
    with gr.Tabs():
        with gr.TabItem("物体识别 (Identify Objects)"):
            with gr.Row():
                input_i = gr.Image(type="pil", label="上传图片")
                output_t = gr.Textbox(label="识别详情", lines=10, interactive=False)
                output_i = gr.Image(type="pil", label="可视化结果")
            btn_i = gr.Button("开始识别")
            btn_i.click(
                fn=identify_objects,
                inputs=input_i,
                outputs=[output_i, output_t]  # image first, then the text report
            )
        with gr.TabItem("🖼️ 抠图 (RMBG)"):
            with gr.Row():
                in1 = gr.Image(type="pil", label="上传图片")
                out1 = gr.Image(label="去背景结果")
            btn1 = gr.Button("开始处理", variant="primary")
            btn1.click(process_rembg, inputs=in1, outputs=out1)
        with gr.TabItem("🎨 动漫化"):
            with gr.Row():
                in2 = gr.Image(type="pil", label="输入照片")
                out2 = gr.Image(label="二次元化结果")
            btn2 = gr.Button("风格转换", variant="primary")
            btn2.click(process_anime, inputs=in2, outputs=out2)
        # The upscale handler loads the 2x model (RealESRGAN_x2plus.pth), so
        # the tab title and output label say 2x rather than 4x.
        with gr.TabItem("🔍 2x 高清修复"):
            with gr.Row():
                in3 = gr.Image(type="pil", label="低清图")
                out3 = gr.Image(label="2倍超分结果")
            btn3 = gr.Button("开始增强", variant="primary")
            btn3.click(process_upscale, inputs=in3, outputs=out3)
        with gr.TabItem("📝 文字识别 (OCR)"):
            with gr.Row():
                with gr.Column():
                    input_img = gr.Image(type="pil", label="上传图片")
                    btn_ocr = gr.Button("开始识别文字", variant="primary")
                with gr.Column():
                    # Plain text recognised in the image.
                    output_text = gr.Textbox(label="识别出的文字", lines=8)
                    # JSON with text-block coordinates, for developers or
                    # downstream features.
                    output_json = gr.Code(label="文字区块坐标 (JSON)", language="json")
            btn_ocr.click(
                fn=process_ocr,
                inputs=input_img,
                outputs=[output_text, output_json]
            )

# Enable the request queue so concurrent requests line up, then start the app.
demo.queue().launch()