| import os
|
| import subprocess
|
| import time
|
| import sys
|
|
|
class JarvisRestorationToolkit:
    """Dispatch image-restoration jobs to per-tool worker scripts.

    Each tool runs as a subprocess under the Python interpreter of the conda
    environment named in its config. Every worker is invoked as
    ``<python> <script> -i <input> -o <output> -m <model>`` from the tool's own
    directory, and results are written into ``Jarvis_Workspace`` next to this
    file.
    """

    def __init__(self):
        # Anchor all paths to this file's directory so the toolkit does not
        # depend on the directory it is launched from.
        self.root_dir = os.path.dirname(os.path.abspath(__file__))
        # Output directory for every worker result (created if missing).
        self.workspace = os.path.join(self.root_dir, "Jarvis_Workspace")
        os.makedirs(self.workspace, exist_ok=True)

        # Environment name -> Python interpreter of that conda env.
        # NOTE(review): machine-specific absolute paths; adjust per install.
        self.envs = {
            "ir_final": r"D:\conda\envs\ir_final\python.exe",
            "swinir_env": r"D:\conda\envs\swinir_env\python.exe",
        }

        # Tool key -> launch description:
        #   env    : key into self.envs selecting the interpreter
        #   cwd    : working directory of the worker process
        #   script : worker entry point, resolved relative to cwd
        #   model  : forwarded verbatim as "-m" (a path, presumably resolved
        #            by the worker relative to cwd, a model id, or the literal
        #            string "None" -- how it is interpreted is up to the worker)
        #   desc   : description shown in the interactive menu
        self.tools_config = {
            "DarkIR": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "DarkIR"),
                "script": "worker_darkir.py",
                "model": r"models/DarkIR_384.pt",
                "desc": "低光增强 (旗舰版) - 适合夜景混合降质"
            },
            "SwinIR": {
                "env": "swinir_env",
                "cwd": os.path.join(self.root_dir, "SwinIR"),
                "script": "worker_swinir.py",
                "model": r"model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth",
                "desc": "超分辨率 (x4) - 适合小图/模糊图"
            },
            "PromptIR": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "PromptIR"),
                "script": "worker_promptir.py",
                "model": r"ckpt/model.ckpt",
                "desc": "去雨去雾 (All-in-One) - 适合恶劣天气"
            },
            "CodeFormer": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "CodeFormer"),
                "script": "worker_codeformer.py",
                "model": "None",
                "desc": "人脸修复 - 适合模糊人像"
            },
            "ZeroDCE": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "Zero-DCE", "Zero-DCE_code"),
                "script": "worker_zerodce.py",
                "model": r"snapshots/Epoch99.pth",
                "desc": "低光增强 (极速版) - 适合实时预览"
            },
            "PowerPaint": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "PowerPaint"),
                "script": "worker_powerpaint.py",
                "model": "runwayml/stable-diffusion-inpainting",
                "desc": "智能补全 - 适合去水印/修补"
            },
            "Restormer": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "Restormer"),
                "script": "worker_restormer_universal.py",
                "model": r"Denoising/pretrained_models/gaussian_color_denoising_blind.pth",
                "desc": "图像去噪 (Denoising) - 消除噪点/颗粒感"
            }
        }

    def _unique_output_path(self, tool_key, input_path):
        """Return a workspace path for ``tool_key``'s result that does not exist yet.

        Names look like ``<stem>_<tool>_<stamp><ext>``, where ``<stamp>`` is the
        Unix time modulo 10000. The stamp is kept short because each pipeline
        step appends to the previous step's name. On collision a ``_<n>``
        counter is appended, so the post-run existence check can only ever see
        the file written by the current run.
        """
        stem, ext = os.path.splitext(os.path.basename(input_path))
        base = f"{stem}_{tool_key}_{int(time.time() % 10000)}"
        candidate = os.path.join(self.workspace, f"{base}{ext}")
        counter = 1
        while os.path.exists(candidate):
            candidate = os.path.join(self.workspace, f"{base}_{counter}{ext}")
            counter += 1
        return candidate

    def _run_worker(self, tool_key, input_path):
        """Run one tool on one image and return the path of the result.

        Args:
            tool_key: key into ``self.tools_config``.
            input_path: image to process; a relative path is resolved against
                the caller's current working directory.

        Returns:
            Absolute path of the newly written output image, or ``None`` on any
            failure: unknown tool, missing input/interpreter/tool directory,
            non-zero exit status, missing output file, or an exception while
            launching the worker. Diagnostics are printed, not raised.
        """
        if tool_key not in self.tools_config:
            print(f"❌ 内部错误:工具KEY '{tool_key}' 未定义")
            return None

        # The worker runs with cwd=cfg["cwd"]; a relative path would be
        # re-resolved there and point at the wrong file (or none).
        input_path = os.path.abspath(input_path)
        if not os.path.exists(input_path):
            print(f"❌ 输入文件不存在: {input_path}")
            return None

        cfg = self.tools_config[tool_key]
        python_exe = self.envs.get(cfg["env"])
        if not python_exe or not os.path.exists(python_exe):
            print(f"❌ 环境路径错误: {python_exe}")
            return None
        if not os.path.isdir(cfg["cwd"]):
            print(f"❌ 工具目录不存在: {cfg['cwd']}")
            return None

        output_path = self._unique_output_path(tool_key, input_path)
        output_name = os.path.basename(output_path)

        cmd = [
            python_exe,
            cfg["script"],
            "-i", input_path,
            "-o", output_path,
            "-m", cfg["model"],
        ]
        # A child Python writing to a pipe uses its locale encoding (e.g. GBK
        # on a Chinese-locale Windows), which strict UTF-8 decoding here would
        # reject. Force the child to UTF-8 and replace any stray bytes.
        child_env = dict(os.environ, PYTHONIOENCODING="utf-8")

        print(f"\n⚡ [{tool_key}] 执行中...")
        try:
            result = subprocess.run(
                cmd,
                cwd=cfg["cwd"],
                env=child_env,
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",
            )
        except Exception as e:  # launch boundary: report and keep the caller alive
            print(f"❌ 异常: {e}")
            return None

        if result.returncode == 0 and os.path.exists(output_path):
            print(f"✅ 完成 -> {output_name}")
            return output_path

        # Distinguish "crashed" from "exited cleanly but wrote nothing".
        reason = "未生成输出文件" if result.returncode == 0 else f"返回码 {result.returncode}"
        print(f"❌ 失败 ({reason}):\n{result.stderr}")
        if result.stdout:
            print(f"--- STDOUT ---\n{result.stdout}")
        return None
|
|
|
def main():
    """Interactive console for chaining restoration tools on one image.

    The starting image comes from the first command-line argument if given,
    otherwise from a hard-coded default. Each round the user either enters a
    sequence of menu numbers (run as a pipeline, each step feeding the next),
    ``test_all``/``test`` (run every tool once on the current image), or ``q``.
    """
    toolkit = JarvisRestorationToolkit()

    # Allow the start image on the command line; keep the old default path.
    default_img = r"G:\datasets\realblur_dataset_test\075_blur_1.png"
    current_img = sys.argv[1] if len(sys.argv) > 1 else default_img

    # Menu number -> tool key in toolkit.tools_config.
    menu = {
        "1": "DarkIR",
        "2": "SwinIR",
        "3": "PromptIR",
        "4": "CodeFormer",
        "5": "ZeroDCE",
        "6": "PowerPaint",
        "7": "Restormer"
    }

    print("="*60)
    print(" JarvisIR 工具链集成测试 (稳定版)")
    print("="*60)

    while True:
        print(f"\n当前处理图片: {current_img}")
        print("-" * 30)
        for k, v in menu.items():
            desc = toolkit.tools_config[v]['desc']
            print(f"[{k}] {v.ljust(12)} : {desc}")
        print("-" * 30)

        try:
            user_input = input("请输入工具序列 (如 '1 2',输入 'test_all' 测试全部,'q' 退出): ").strip()
        except (EOFError, KeyboardInterrupt):
            # End-of-input / Ctrl+C at the prompt behaves like 'q'.
            print()
            break

        if user_input.lower() == 'q':
            break

        if user_input.lower() in ('test_all', 'test'):
            print("\n🚀 启动全工具冒烟测试...")
            failed = []
            for k, tool_key in menu.items():
                print(f"\n========== 测试 {k}. {tool_key} ==========")
                if toolkit._run_worker(tool_key, current_img) is None:
                    failed.append(tool_key)
            print("\n✅ 全测试结束!(已移除不稳定工具)")
            if failed:
                print(f"⚠️ 失败的工具: {', '.join(failed)}")
            continue

        valid_pipeline = []
        for s in user_input.split():
            if s in menu:
                valid_pipeline.append(menu[s])
            else:
                print(f"⚠️ 跳过无效输入: {s}")

        if not valid_pipeline:
            continue

        print(f"\n🚀 启动流水线: {' -> '.join(valid_pipeline)}")
        temp_img = current_img
        completed = 0
        for step, tool_key in enumerate(valid_pipeline, start=1):
            print(f"\n>>> Step {step}/{len(valid_pipeline)}: 调用 {tool_key}")
            res = toolkit._run_worker(tool_key, temp_img)
            if res is None:
                print("🚨 流水线中断!")
                break
            temp_img = res
            completed += 1

        if completed == 0:
            # Nothing was produced; don't present the untouched input as a result.
            continue
        if completed < len(valid_pipeline):
            print(f"\n⚠️ 流水线未完成,最后一个中间结果: {temp_img}")
        else:
            print(f"\n🎁 最终结果已保存: {temp_img}")

        try:
            answer = input("是否将此结果作为下一轮输入? (y/n): ")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if answer.strip().lower() == 'y':
            current_img = temp_img


if __name__ == "__main__":
    main()