"""ExtractIt AI: a Gradio app offering object detection, background removal,
anime stylisation, 2x super-resolution and OCR. All models run on CPU."""

import json
import os
import traceback
from typing import Optional, Tuple

# Third-party imports are kept in their original relative order on purpose:
# native libraries (torch / onnxruntime / cv2) can be import-order sensitive.
import gradio as gr
import numpy as np
import cv2
import onnxruntime as ort
from PIL import Image, ImageEnhance, ImageFilter, ImageDraw, ImageFont
from rembg import remove, new_session
import easyocr
import torch
from spandrel import ModelLoader
from transformers import pipeline, AutoModelForImageSegmentation, AutoConfig

# --- Global state and model cache ---
# Every model is loaded lazily on first use and then reused across requests.
models = {
    "birefnet": None,
    "bria_session": None,  # this entry must be present
    "ocr": None,
    "anime_session": None,
    "upscaler": None,
    "object_detector": None,
    "ocr_reader": None,
}

# Detections / OCR fragments at or below these confidences are dropped.
DETECTION_MIN_SCORE = 0.5
OCR_MIN_CONFIDENCE = 0.3

UPSCALE_MODEL_PATH = "RealESRGAN_x2plus.pth"
ANIME_MODEL_PATH = "AnimeGANv3_large_Ghibli_c1_e299.onnx"


# 1. Object detection
def identify_objects(
    input_img: Optional[Image.Image],
) -> Tuple[Optional[Image.Image], str]:
    """Detect objects and return an annotated image plus a JSON report.

    Args:
        input_img: PIL image from the Gradio input, or ``None`` if nothing
            was uploaded.

    Returns:
        ``(image, json_text)``. On success the image is an RGB copy with a red
        box around every detection scoring above ``DETECTION_MIN_SCORE`` and
        the JSON holds ``status``, ``count`` and ``detections``. On failure
        the original image is returned with a ``status: error`` JSON payload.
    """
    if input_img is None:
        return None, json.dumps({"error": "请上传图片"}, ensure_ascii=False)

    try:
        if models.get("object_detector") is None:
            models["object_detector"] = pipeline(
                "object-detection",
                model="hustvl/yolos-tiny",
                device=-1,
            )

        # convert() yields a new image, so drawing never touches the upload.
        img_pil = input_img.convert("RGB")
        results = models["object_detector"](img_pil)
        draw = ImageDraw.Draw(img_pil)

        detections = []
        for res in results:
            score = res["score"]
            if score <= DETECTION_MIN_SCORE:
                continue

            box = res["box"]
            draw.rectangle(
                [box["xmin"], box["ymin"], box["xmax"], box["ymax"]],
                outline="red",
                width=4,
            )
            detections.append({
                "id": len(detections) + 1,
                "label": res["label"],
                # float() guards against non-native numeric types in JSON.
                "confidence": round(float(score), 4),
                "box": {
                    "xmin": int(box["xmin"]),
                    "ymin": int(box["ymin"]),
                    "xmax": int(box["xmax"]),
                    "ymax": int(box["ymax"]),
                },
            })

        output_json = {
            "status": "success",
            "count": len(detections),
            "detections": detections,
        }
        # ensure_ascii=False keeps non-ASCII labels readable in the textbox.
        return img_pil, json.dumps(output_json, ensure_ascii=False, indent=4)

    except Exception as e:
        error_json = {
            "status": "error",
            "message": str(e),
        }
        return input_img, json.dumps(error_json, ensure_ascii=False)


def process_rembg(input_img: Optional[Image.Image]) -> Optional[Image.Image]:
    """Remove the background of an image with BiRefNet (lite).

    Args:
        input_img: PIL image, or ``None`` if nothing was uploaded.

    Returns:
        An RGBA copy of the input whose alpha channel is the predicted
        foreground mask, or ``None`` when no image was supplied.

    Raises:
        gr.Error: if loading the model or running inference fails. Raising
            (rather than returning the message) matters because the output is
            a ``gr.Image``, which would treat a returned string as a path.
    """
    if input_img is None:
        return None

    try:
        if models["birefnet"] is None:
            model_id = "ZhengPeng7/BiRefNet_lite"
            models["birefnet"] = AutoModelForImageSegmentation.from_pretrained(
                model_id,
                trust_remote_code=True,
            )
            # Patch: some transformers versions expect this attribute.
            if not hasattr(models["birefnet"], "all_tied_weights_keys"):
                models["birefnet"].all_tied_weights_keys = []
            models["birefnet"].to("cpu").eval()

        img_pil = input_img.convert("RGB")
        w, h = img_pil.size

        # Pre-processing: resize to 1024x1024 and scale to [0, 1].
        img_resized = img_pil.resize((1024, 1024), Image.BILINEAR)
        img_np = np.array(img_resized).astype(np.float32) / 255.0

        # Standard ImageNet normalisation parameters.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        img_np = (img_np - mean) / std
        img_np = img_np.transpose((2, 0, 1))  # HWC -> CHW

        # Force float32 so the tensor dtype matches the model weights.
        img_tensor = torch.from_numpy(img_np).unsqueeze(0).float()

        with torch.no_grad():
            # The model returns several outputs; the last one is used.
            preds = models["birefnet"](img_tensor)[-1].sigmoid().cpu()

        # Bring the mask back to the original resolution.
        mask = torch.nn.functional.interpolate(
            preds,
            size=(h, w),
            mode="bilinear",
            align_corners=False,
        )[0, 0]
        mask_np = (mask.numpy() * 255).astype(np.uint8)

        # A light blur softens the mask edge.
        mask_pil = Image.fromarray(mask_np)
        mask_pil = mask_pil.filter(ImageFilter.GaussianBlur(radius=0.8))

        result = img_pil.copy()
        result.putalpha(mask_pil)
        return result

    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"BiRefNet 最终修复版失败: {str(e)}") from e


# 2. Text recognition (OCR)
def process_ocr(input_img: Optional[Image.Image]) -> Tuple[str, str]:
    """Run EasyOCR (Simplified Chinese + English) on an image.

    Args:
        input_img: PIL image, or ``None`` if nothing was uploaded.

    Returns:
        ``(text, json_text)``: the recognised lines joined by newlines, and a
        JSON list of ``{"text", "confidence", "bbox"}`` entries. When nothing
        is recognised, no image is given, or an error occurs, the first item
        is a human-readable message and the second is ``"{}"``.
    """
    if input_img is None:
        return "未上传图片", "{}"

    try:
        # Lazily create the reader.
        if models.get("ocr_reader") is None:
            print("正在初始化 OCR 阅读器...")
            models["ocr_reader"] = easyocr.Reader(['ch_sim', 'en'], gpu=False)

        # Normalise to RGB so palette / alpha modes give a plain HxWx3 array.
        img_np = np.array(input_img.convert("RGB"))
        results = models["ocr_reader"].readtext(img_np)

        lines = []
        ocr_details = []
        for bbox, text, prob in results:
            if prob <= OCR_MIN_CONFIDENCE:
                continue
            lines.append(f"{text}\n")
            ocr_details.append({
                "text": text,
                "confidence": round(float(prob), 4),
                # bbox is four [x, y] corner points; cast to plain ints so
                # the structure is JSON-serialisable.
                "bbox": [[int(x), int(y)] for x, y in bbox],
            })

        full_text = "".join(lines)
        if not full_text:
            return "未识别到文字", "{}"

        json_output = json.dumps(ocr_details, ensure_ascii=False, indent=2)
        return full_text, json_output

    except Exception as e:
        return f"OCR 识别出错: {str(e)}", "{}"


# 3. Super-resolution (2x)
def process_upscale(input_img: Optional[Image.Image]) -> Optional[Image.Image]:
    """Upscale an image 2x with the RealESRGAN x2plus checkpoint.

    Inputs whose longest side exceeds 1024px are shrunk to 1024px first so CPU
    inference stays fast; the result is then lightly sharpened.

    Args:
        input_img: PIL image, or ``None`` if nothing was uploaded.

    Returns:
        The upscaled RGB image, or ``None`` when no image was supplied.

    Raises:
        gr.Error: if the checkpoint file is missing or inference fails. The
            output is a ``gr.Image``, so the message must be raised rather
            than returned.
    """
    if input_img is None:
        return None

    model_path = UPSCALE_MODEL_PATH
    if not os.path.exists(model_path):
        raise gr.Error("请确认 RealESRGAN_x2plus.pth 已上传")

    try:
        if models["upscaler"] is None:
            loader = ModelLoader()
            ckpt = loader.load_from_file(model_path)
            models["upscaler"] = ckpt.model.to("cpu").eval()
            # Cap the thread count so a full-load CPU does not freeze the page.
            torch.set_num_threads(4)

        # convert() returns a new image, so thumbnail() below (in-place)
        # does not modify the caller's upload.
        img_pil = input_img.convert("RGB")
        w, h = img_pil.size

        # Shrink large inputs (longest side > 1024px) before the 2x pass.
        max_input = 1024
        if max(w, h) > max_input:
            img_pil.thumbnail((max_input, max_input), Image.LANCZOS)

        img = np.array(img_pil).astype(np.float32) / 255.0
        # HWC -> CHW, add batch dimension.
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
        img_tensor = img_tensor.float()

        with torch.no_grad():
            output = models["upscaler"](img_tensor)

        output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy()
        output = (output * 255).astype(np.uint8)
        result_pil = Image.fromarray(output)

        # Only 2x is applied, so add mild sharpening for perceived clarity.
        enhancer = ImageEnhance.Sharpness(result_pil)
        result_pil = enhancer.enhance(1.2)
        return result_pil

    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"2x 超分失败: {str(e)}") from e


def fast_smart_sharpen(img: Image.Image) -> Image.Image:
    """Return a 2x Lanczos-enlarged, unsharp-masked, slightly contrastier copy.

    A purely algorithmic (model-free) clarity boost.
    """
    width, height = img.size
    enlarged = img.resize((width * 2, height * 2), Image.LANCZOS)
    # Unsharp mask strengthens object edges.
    sharpened = enlarged.filter(
        ImageFilter.UnsharpMask(radius=2, percent=120, threshold=2)
    )
    # A touch more contrast so the picture looks less flat.
    return ImageEnhance.Contrast(sharpened).enhance(1.05)


def _anime_input_size(w, h, limit=1024, multiple=32):
    """Return (width, height) for the anime model input.

    The longest side is capped at ``limit`` and both sides are rounded down to
    a multiple of ``multiple``, but never below ``multiple`` itself, so tiny
    or extremely elongated images cannot produce a zero-sized dimension.
    """
    longest = max(w, h)
    scale = limit / longest if longest > limit else 1.0
    target_w = max(multiple, int((w * scale) // multiple) * multiple)
    target_h = max(multiple, int((h * scale) // multiple) * multiple)
    return target_w, target_h


# 4. Anime stylisation (uses AnimeGANv3_large_Ghibli_c1_e299.onnx)
def process_anime(input_img: Optional[Image.Image]) -> Optional[Image.Image]:
    """Stylise a photo with the AnimeGANv3 Ghibli ONNX model.

    Args:
        input_img: PIL image, or ``None`` if nothing was uploaded.

    Returns:
        The stylised RGB image resized back to the input's dimensions, or
        ``None`` when no image was supplied.

    Raises:
        gr.Error: if the ONNX file is missing or inference fails. The output
            is a ``gr.Image``, so the message must be raised rather than
            returned.
    """
    if input_img is None:
        return None

    model_path = ANIME_MODEL_PATH
    if not os.path.exists(model_path):
        raise gr.Error(f"错误:未找到模型文件 {model_path},请确认已上传到根目录")

    try:
        if models["anime_session"] is None:
            print(f"正在加载模型: {model_path}...")
            models["anime_session"] = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider'],
            )

        session = models["anime_session"]
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        # Pre-processing: cap at 1024px, snap to multiples of 32, map to [-1, 1].
        img_pil = input_img.convert("RGB")
        w, h = img_pil.size
        target_w, target_h = _anime_input_size(w, h)

        img_resized = img_pil.resize((target_w, target_h), Image.LANCZOS)
        img_np = np.array(img_resized).astype(np.float32)
        img_np = (img_np / 127.5) - 1.0
        img_np = np.expand_dims(img_np, axis=0)

        output = session.run([output_name], {input_name: img_np})[0]

        # Post-processing: map [-1, 1] back to uint8 pixels.
        output = np.squeeze(output)
        output = (output + 1.0) * 127.5
        output = np.clip(output, 0, 255).astype(np.uint8)
        result_pil = Image.fromarray(output)

        result_pil = ImageEnhance.Sharpness(result_pil).enhance(1.5)
        result_pil = ImageEnhance.Contrast(result_pil).enhance(1.1)
        return result_pil.resize((w, h), Image.LANCZOS)

    except Exception as e:
        # Full traceback goes to the logs; the UI gets a short message.
        print(traceback.format_exc())
        raise gr.Error(f"推理失败: {str(e)}") from e


# --- Gradio interface ---
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# 🚀 ExtractIt AI 多功能平台")

    with gr.Tabs():
        with gr.TabItem("物体识别 (Identify Objects)"):
            with gr.Row():
                input_i = gr.Image(type="pil", label="上传图片")
                output_t = gr.Textbox(label="识别详情", lines=10, interactive=False)
                output_i = gr.Image(type="pil", label="可视化结果")
            btn_i = gr.Button("开始识别")
            btn_i.click(
                fn=identify_objects,
                inputs=input_i,
                outputs=[output_i, output_t],  # image first, then text
            )

        with gr.TabItem("🖼️ 抠图 (RMBG)"):
            with gr.Row():
                in1 = gr.Image(type="pil", label="上传图片")
                out1 = gr.Image(label="去背景结果")
            btn1 = gr.Button("开始处理", variant="primary")
            btn1.click(process_rembg, inputs=in1, outputs=out1)

        with gr.TabItem("🎨 动漫化"):
            with gr.Row():
                in2 = gr.Image(type="pil", label="输入照片")
                out2 = gr.Image(label="二次元化结果")
            btn2 = gr.Button("风格转换", variant="primary")
            btn2.click(process_anime, inputs=in2, outputs=out2)

        # Labels say 2x because process_upscale loads the x2plus checkpoint.
        with gr.TabItem("🔍 2x 高清修复"):
            with gr.Row():
                in3 = gr.Image(type="pil", label="低清图")
                out3 = gr.Image(label="2倍超分结果")
            btn3 = gr.Button("开始增强", variant="primary")
            btn3.click(process_upscale, inputs=in3, outputs=out3)

        with gr.TabItem("📝 文字识别 (OCR)"):
            with gr.Row():
                with gr.Column():
                    input_img = gr.Image(type="pil", label="上传图片")
                    btn_ocr = gr.Button("开始识别文字", variant="primary")
                with gr.Column():
                    # Plain recognised text.
                    output_text = gr.Textbox(label="识别出的文字", lines=8)
                    # Per-fragment coordinates as JSON for downstream use.
                    output_json = gr.Code(label="文字区块坐标 (JSON)", language="json")
            btn_ocr.click(
                fn=process_ocr,
                inputs=input_img,
                outputs=[output_text, output_json],
            )

# Enable the request queue so concurrent users line up, then start the app.
demo.queue().launch()