Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import onnxruntime as ort | |
| from PIL import Image, ImageEnhance, ImageFilter, ImageDraw, ImageFont | |
| from rembg import remove, new_session | |
| import easyocr | |
| import torch | |
| from spandrel import ModelLoader | |
| import os | |
| import onnxruntime as ort | |
| from transformers import pipeline, AutoModelForImageSegmentation, AutoConfig | |
| import json | |
# --- Global state: shared model cache ---
# Every slot starts out as None; the handlers below fill in the slots they
# use the first time they are called, so models are only loaded on demand.
models = dict.fromkeys(
    (
        "birefnet",
        "bria_session",  # original note: this entry must be present
        "ocr",
        "anime_session",
        "upscaler",
    )
)
def identify_objects(input_img):
    """Detect objects in an image and report them as an annotated image plus JSON.

    Args:
        input_img: PIL image handed over by the UI, or None when nothing
            was uploaded.

    Returns:
        A ``(image, json_text)`` pair:

        * no input  -> ``(None, {"error": ...})``
        * success   -> an RGB copy of the input with a red rectangle around
          every detection scoring above 0.5, and a JSON document holding
          ``status``, ``count`` and ``detections``
        * exception -> the untouched input image and a JSON document holding
          ``status`` = ``"error"`` and the exception text in ``message``
    """
    global models

    if input_img is None:
        return None, json.dumps({"error": "请上传图片"}, ensure_ascii=False)

    try:
        # Build the detection pipeline on first use and keep it in the cache.
        if "object_detector" not in models:
            from transformers import pipeline

            models["object_detector"] = pipeline(
                "object-detection",
                model="hustvl/yolos-tiny",
                device=-1,
            )
        detector = models["object_detector"]

        canvas = input_img.convert("RGB")
        raw_hits = detector(canvas)
        painter = ImageDraw.Draw(canvas)

        # Only detections above the confidence threshold are drawn/reported.
        confident_hits = [hit for hit in raw_hits if hit["score"] > 0.5]

        detections = []
        for index, hit in enumerate(confident_hits, start=1):
            bounds = hit["box"]
            painter.rectangle(
                [bounds["xmin"], bounds["ymin"], bounds["xmax"], bounds["ymax"]],
                outline="red",
                width=4,
            )
            detections.append(
                {
                    "id": index,
                    "label": hit["label"],
                    "confidence": round(hit["score"], 4),  # keep 4 decimals
                    "box": {
                        "xmin": int(bounds["xmin"]),
                        "ymin": int(bounds["ymin"]),
                        "xmax": int(bounds["xmax"]),
                        "ymax": int(bounds["ymax"]),
                    },
                }
            )

        payload = {
            "status": "success",
            "count": len(detections),
            "detections": detections,
        }
        # ensure_ascii=False keeps non-ASCII labels readable in the textbox.
        return canvas, json.dumps(payload, ensure_ascii=False, indent=4)
    except Exception as e:
        failure = {"status": "error", "message": str(e)}
        return input_img, json.dumps(failure, ensure_ascii=False)
def process_rembg(input_img):
    """Remove the background of an image using BiRefNet-lite on the CPU.

    Args:
        input_img: PIL image handed over by the UI, or None.

    Returns:
        An RGB copy of the input with the predicted foreground mask attached
        as its alpha channel, or None when no image was supplied.

    Raises:
        gr.Error: if model loading or inference fails. The output component
            is an image, so a returned error string would be interpreted as
            a file path; raising lets Gradio show the message instead.
    """
    global models
    if input_img is None:
        return None

    try:
        # Lazy model loading.
        if models["birefnet"] is None:
            from transformers import AutoModelForImageSegmentation

            model_id = "ZhengPeng7/BiRefNet_lite"
            net = AutoModelForImageSegmentation.from_pretrained(
                model_id, trust_remote_code=True
            )
            # Patch: some transformers versions fail without this attribute.
            if not hasattr(net, "all_tied_weights_keys"):
                net.all_tied_weights_keys = []
            net.to("cpu").eval()
            # Cache only once the model is fully set up, so a failure above
            # never leaves a half-initialised model behind for later calls.
            models["birefnet"] = net

        img_pil = input_img.convert("RGB")
        w, h = img_pil.size

        # Pre-processing: resize to 1024x1024 and scale to [0, 1].
        img_resized = img_pil.resize((1024, 1024), Image.BILINEAR)
        img_np = np.array(img_resized).astype(np.float32) / 255.0

        # Standard ImageNet normalisation constants.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        img_np = (img_np - mean) / std
        img_np = img_np.transpose((2, 0, 1))  # HWC -> CHW

        # Force float32 so the tensor dtype matches the model weights.
        img_tensor = torch.from_numpy(img_np).unsqueeze(0).float()

        with torch.no_grad():
            # Per the original note, BiRefNet returns multi-scale outputs;
            # the last one is used.
            preds = models["birefnet"](img_tensor)[-1].sigmoid().cpu()

        # Bring the mask back to the original resolution.
        mask = torch.nn.functional.interpolate(
            preds, size=(h, w), mode='bilinear', align_corners=False
        )[0, 0]
        mask_np = (mask.numpy() * 255).astype(np.uint8)

        # Slight blur to soften the mask edge.
        mask_pil = Image.fromarray(mask_np)
        mask_pil = mask_pil.filter(ImageFilter.GaussianBlur(radius=0.8))

        # Compose: the mask becomes the alpha channel.
        result = img_pil.copy()
        result.putalpha(mask_pil)
        return result
    except Exception as e:
        import traceback

        # Full traceback goes to the logs; the UI gets the short message.
        print(traceback.format_exc())
        raise gr.Error(f"BiRefNet 最终修复版失败: {str(e)}") from e
# 2. Text recognition (OCR)
def process_ocr(input_img):
    """Run EasyOCR (simplified Chinese + English) on an image.

    Args:
        input_img: PIL image handed over by the UI, or None.

    Returns:
        A ``(text, json_text)`` pair. ``text`` holds one recognised string
        per line; ``json_text`` is a JSON list of ``text`` / ``confidence`` /
        ``bbox`` entries. Only results with confidence above 0.3 are kept.
        When there is no input, nothing is recognised, or an exception
        occurs, ``text`` is a short message and ``json_text`` is ``"{}"``.
    """
    global models

    if input_img is None:
        return "未上传图片", "{}"

    try:
        # Create the reader on first use only and keep it in the cache.
        if "ocr_reader" not in models:
            print("正在初始化 OCR 阅读器...")
            models["ocr_reader"] = easyocr.Reader(['ch_sim', 'en'], gpu=False)
        reader = models["ocr_reader"]

        pixels = np.array(input_img)
        accepted = [
            (corners, content, score)
            for corners, content, score in reader.readtext(pixels)
            if score > 0.3
        ]

        # Each bbox is four [x, y] corner points; cast to plain ints so the
        # structure is JSON-serialisable.
        ocr_details = [
            {
                "text": content,
                "confidence": round(float(score), 4),
                "bbox": [[int(x), int(y)] for x, y in corners],
            }
            for corners, content, score in accepted
        ]
        json_output = json.dumps(ocr_details, ensure_ascii=False, indent=2)

        # One recognised string per line, each terminated by a newline.
        full_text = "".join(f"{content}\n" for _, content, _ in accepted)
        if not full_text:
            return "未识别到文字", "{}"
        return full_text, json_output
    except Exception as e:
        return f"OCR 识别出错: {str(e)}", "{}"
def process_upscale(input_img):
    """Upscale an image 2x with the RealESRGAN_x2plus weights on the CPU.

    Args:
        input_img: PIL image handed over by the UI, or None.

    Returns:
        The upscaled and lightly sharpened PIL image, or None when no image
        was supplied. Inputs whose longer side exceeds 1024 px are shrunk to
        fit 1024 px before the 2x pass.

    Raises:
        gr.Error: if the weights file is missing or inference fails. The
            output component is an image, so a returned error string would
            be interpreted as a file path; raising lets Gradio show the
            message instead.
    """
    if input_img is None:
        return None

    # Path of the 2x model weights.
    model_path = "RealESRGAN_x2plus.pth"
    if not os.path.exists(model_path):
        raise gr.Error("请确认 RealESRGAN_x2plus.pth 已上传")

    try:
        if models["upscaler"] is None:
            # Load the checkpoint once and cache the network in eval mode.
            loader = ModelLoader()
            ckpt = loader.load_from_file(model_path)
            models["upscaler"] = ckpt.model.to("cpu").eval()

        # Cap the thread count so a saturated CPU does not freeze the page.
        torch.set_num_threads(4)

        # --- Pre-processing ---
        img_pil = input_img.convert("RGB")
        w, h = img_pil.size

        # Large inputs (longer side above 1024 px) are shrunk first and then
        # upscaled 2x, trading some resolution for speed.
        max_input = 1024
        if max(w, h) > max_input:
            img_pil.thumbnail((max_input, max_input), Image.LANCZOS)

        img = np.array(img_pil).astype(np.float32) / 255.0
        # HWC -> CHW, add the batch dimension, force float32.
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
        img_tensor = img_tensor.float()

        # --- Inference ---
        with torch.no_grad():
            output = models["upscaler"](img_tensor)

        output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy()
        output = (output * 255).astype(np.uint8)
        result_pil = Image.fromarray(output)

        # --- Visual compensation ---
        # Only 2x magnification, so a mild sharpen adds perceived crispness.
        enhancer = ImageEnhance.Sharpness(result_pil)
        result_pil = enhancer.enhance(1.2)
        return result_pil
    except Exception as e:
        import traceback

        # Full traceback goes to the logs; the UI gets the short message.
        print(traceback.format_exc())
        raise gr.Error(f"2x 超分失败: {str(e)}") from e
def fast_smart_sharpen(img):
    """Model-free clarity boost: algorithmic upscale plus edge sharpening.

    The image is doubled in size with Lanczos resampling, passed through an
    unsharp mask to harden edges, and given a slight contrast lift.

    Args:
        img: PIL image to enhance.

    Returns:
        A new PIL image twice the width and height of ``img``.
    """
    width, height = img.size
    doubled = img.resize((width * 2, height * 2), Image.LANCZOS)

    # Unsharp mask emphasises object edges so soft photos look firmer.
    edge_boost = ImageFilter.UnsharpMask(radius=2, percent=120, threshold=2)
    crisp = doubled.filter(edge_boost)

    # A touch more contrast so the result looks less flat.
    return ImageEnhance.Contrast(crisp).enhance(1.05)
# 4. Anime stylisation (uses the uploaded AnimeGANv3_large_Ghibli_c1_e299.onnx)
def process_anime(input_img):
    """Stylise a photo with the AnimeGANv3 Ghibli ONNX model on the CPU.

    Args:
        input_img: PIL image handed over by the UI, or None.

    Returns:
        The stylised PIL image resized back to the input's dimensions, or
        None when no image was supplied.

    Raises:
        gr.Error: if the model file is missing or inference fails. The
            output component is an image, so a returned error string would
            be interpreted as a file path; raising lets Gradio show the
            message instead.
    """
    if input_img is None:
        return None

    # 1. The model file must exist before anything else is attempted.
    model_path = "AnimeGANv3_large_Ghibli_c1_e299.onnx"
    if not os.path.exists(model_path):
        raise gr.Error(f"错误:未找到模型文件 {model_path},请确认已上传到根目录")

    try:
        # 2. Create the inference session on first use and cache it.
        if models["anime_session"] is None:
            print(f"正在加载模型: {model_path}...")
            models["anime_session"] = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )

        # 3. Look up the graph's input/output node names.
        session = models["anime_session"]
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        # 4. Pre-processing: cap the longer side at 1024 px and round both
        #    sides down to a multiple of 32.
        img_pil = input_img.convert("RGB")
        w, h = img_pil.size
        process_size = 1024
        scale = process_size / max(w, h) if max(w, h) > process_size else 1.0
        # Rounding down yields 0 for any side shorter than 32 px, which makes
        # resize() fail; clamp each side to at least one 32 px block.
        target_w = max(32, int((w * scale) // 32) * 32)
        target_h = max(32, int((h * scale) // 32) * 32)

        img_resized = img_pil.resize((target_w, target_h), Image.LANCZOS)
        img_np = np.array(img_resized).astype(np.float32)
        img_np = (img_np / 127.5) - 1.0  # [0, 255] -> [-1, 1]
        img_np = np.expand_dims(img_np, axis=0)  # add the batch dimension

        # 5. Inference.
        output = session.run([output_name], {input_name: img_np})[0]

        # 6. Post-processing: map [-1, 1] back to 8-bit pixels.
        output = np.squeeze(output)
        output = (output + 1.0) * 127.5
        output = np.clip(output, 0, 255).astype(np.uint8)
        result_pil = Image.fromarray(output)

        # Mild sharpening and contrast boost on the stylised result.
        result_pil = ImageEnhance.Sharpness(result_pil).enhance(1.5)
        result_pil = ImageEnhance.Contrast(result_pil).enhance(1.1)
        return result_pil.resize((w, h), Image.LANCZOS)
    except Exception as e:
        import traceback

        # Full traceback goes to the logs; the UI gets the short message.
        print(traceback.format_exc())
        raise gr.Error(f"推理失败: {str(e)}") from e
# --- Gradio interface layout ---
# One tab per feature; each tab wires an input image and a button to the
# matching handler defined above.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# 🚀 ExtractIt AI 多功能平台")
    with gr.Tabs():
        with gr.TabItem("物体识别 (Identify Objects)"):
            with gr.Row():
                input_i = gr.Image(type="pil", label="上传图片")
                output_t = gr.Textbox(label="识别详情", lines=10, interactive=False)
                output_i = gr.Image(type="pil", label="可视化结果")
            btn_i = gr.Button("开始识别")
            btn_i.click(
                fn=identify_objects,
                inputs=input_i,
                outputs=[output_i, output_t]  # image first, then the JSON text
            )
        with gr.TabItem("🖼️ 抠图 (RMBG)"):
            with gr.Row():
                in1 = gr.Image(type="pil", label="上传图片")
                out1 = gr.Image(label="去背景结果")
            btn1 = gr.Button("开始处理", variant="primary")
            btn1.click(process_rembg, inputs=in1, outputs=out1)
        with gr.TabItem("🎨 动漫化"):
            with gr.Row():
                in2 = gr.Image(type="pil", label="输入照片")
                out2 = gr.Image(label="二次元化结果")
            btn2 = gr.Button("风格转换", variant="primary")
            btn2.click(process_anime, inputs=in2, outputs=out2)
        # The upscaler loads the RealESRGAN_x2plus weights, so the labels say
        # 2x (they previously advertised 4x).
        with gr.TabItem("🔍 2x 高清修复"):
            with gr.Row():
                in3 = gr.Image(type="pil", label="低清图")
                out3 = gr.Image(label="2倍超分结果")
            btn3 = gr.Button("开始增强", variant="primary")
            btn3.click(process_upscale, inputs=in3, outputs=out3)
        with gr.TabItem("📝 文字识别 (OCR)"):
            with gr.Row():
                with gr.Column():
                    input_img = gr.Image(type="pil", label="上传图片")
                    btn_ocr = gr.Button("开始识别文字", variant="primary")
                with gr.Column():
                    # Plain recognised text.
                    output_text = gr.Textbox(label="识别出的文字", lines=8)
                    # Per-block coordinates as JSON, for developers or
                    # follow-up processing.
                    output_json = gr.Code(label="文字区块坐标 (JSON)", language="json")
            btn_ocr.click(
                fn=process_ocr,
                inputs=input_img,
                outputs=[output_text, output_json]
            )
# Turn on the request queue so concurrent requests wait their turn,
# then start the server.
demo.queue()
demo.launch()