# IR_expeiment / PART2 / jarvis_tools_api.py
# Commit 2ecc7ab (verified) by hugaagg: "Upload folder using huggingface_hub"
import os
import subprocess
import time
import sys
class JarvisRestorationToolkit:
    """Registry and launcher for image-restoration worker scripts.

    Each tool is executed as a separate process,
    ``<env python> <script> -i INPUT -o OUTPUT -m MODEL``, using the interpreter
    of its own environment and running inside the tool's directory. Outputs are
    written to ``self.workspace``.
    """

    def __init__(self, envs=None, workspace=None):
        """Build the tool registry and create the output workspace.

        Args:
            envs: optional mapping ``env name -> python executable`` that
                overrides/extends the built-in defaults (the defaults are
                machine-specific Windows conda paths).
            workspace: optional output directory. Defaults to
                ``<directory of this file>/Jarvis_Workspace``. It is stored as
                an absolute path because workers run with a different cwd.
        """
        # ================= 1. Base configuration =================
        self.root_dir = os.path.dirname(os.path.abspath(__file__))
        if workspace is None:
            workspace = os.path.join(self.root_dir, "Jarvis_Workspace")
        self.workspace = os.path.abspath(workspace)
        os.makedirs(self.workspace, exist_ok=True)

        # ================= 2. Interpreter registry =================
        # env name -> python executable of that environment.
        self.envs = {
            "ir_final": r"D:\conda\envs\ir_final\python.exe",
            "swinir_env": r"D:\conda\envs\swinir_env\python.exe",
        }
        if envs:
            self.envs.update(envs)

        # ================= 3. Tool registry (NAFNet removed) =================
        # Per tool: "env" is a key into self.envs, "cwd" is the worker's working
        # directory, "script" is run relative to cwd, "model" is passed verbatim
        # as -m (relative paths presumably resolved by the worker against cwd),
        # and "desc" is the menu text.
        self.tools_config = {
            # --- ID: 1 ---
            "DarkIR": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "DarkIR"),
                "script": "worker_darkir.py",
                "model": r"models/DarkIR_384.pt",
                "desc": "低光增强 (旗舰版) - 适合夜景混合降质",
            },
            # --- ID: 2 ---
            "SwinIR": {
                "env": "swinir_env",
                "cwd": os.path.join(self.root_dir, "SwinIR"),
                "script": "worker_swinir.py",
                "model": r"model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth",
                "desc": "超分辨率 (x4) - 适合小图/模糊图",
            },
            # --- ID: 3 ---
            "PromptIR": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "PromptIR"),
                "script": "worker_promptir.py",
                "model": r"ckpt/model.ckpt",
                "desc": "去雨去雾 (All-in-One) - 适合恶劣天气",
            },
            # --- ID: 4 ---
            "CodeFormer": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "CodeFormer"),
                "script": "worker_codeformer.py",
                # The literal string "None" is passed as -m; presumably the
                # worker then uses its built-in weights (TODO confirm).
                "model": "None",
                "desc": "人脸修复 - 适合模糊人像",
            },
            # --- [removed] NAFNet ---
            # --- ID: 5 (formerly 6) ---
            "ZeroDCE": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "Zero-DCE", "Zero-DCE_code"),
                "script": "worker_zerodce.py",
                "model": r"snapshots/Epoch99.pth",
                "desc": "低光增强 (极速版) - 适合实时预览",
            },
            # --- ID: 6 (formerly 7) ---
            "PowerPaint": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "PowerPaint"),
                "script": "worker_powerpaint.py",
                # Looks like a Hugging Face model id rather than a local path.
                "model": "runwayml/stable-diffusion-inpainting",
                "desc": "智能补全 - 适合去水印/修补",
            },
            # --- ID: 7 (formerly 8) ---
            "Restormer": {
                "env": "ir_final",
                "cwd": os.path.join(self.root_dir, "Restormer"),
                "script": "worker_restormer_universal.py",
                "model": r"Denoising/pretrained_models/gaussian_color_denoising_blind.pth",
                "desc": "图像去噪 (Denoising) - 消除噪点/颗粒感",
            },
        }

    def _make_output_path(self, input_path, tool_key):
        """Return a workspace path for tool_key's output that does not exist yet.

        The name is ``<input stem>_<tool_key>_<stamp><ext>``, with ``_<n>``
        appended on collision. This ensures a file left over from an earlier
        run is never mistaken for fresh output.
        """
        # The short stamp keeps chained pipeline names readable. It repeats
        # every 10000 s, hence the explicit collision check below.
        stamp = int(time.time() % 10000)
        stem, ext = os.path.splitext(os.path.basename(input_path))
        base = f"{stem}_{tool_key}_{stamp}"
        candidate = os.path.join(self.workspace, base + ext)
        counter = 1
        while os.path.exists(candidate):
            candidate = os.path.join(self.workspace, f"{base}_{counter}{ext}")
            counter += 1
        return candidate

    def _run_worker(self, tool_key, input_path, *, timeout=None):
        """Run one registered tool on ``input_path`` in its own interpreter.

        Args:
            tool_key: key into ``self.tools_config``.
            input_path: image to process. A relative path is resolved against
                the caller's cwd, not the tool's cwd.
            timeout: optional number of seconds before the worker is killed;
                None waits indefinitely.

        Returns:
            Absolute path of the produced image, or None on any failure. The
            reason for a failure is printed.
        """
        if tool_key not in self.tools_config:
            print(f"❌ 内部错误:工具KEY '{tool_key}' 未定义")
            return None
        if not os.path.exists(input_path):
            print(f"❌ 输入文件不存在: {input_path}")
            return None
        # The worker runs with cwd=cfg["cwd"], so a relative input path must be
        # made absolute here or the child would look in the wrong directory.
        input_path = os.path.abspath(input_path)

        cfg = self.tools_config[tool_key]
        python_exe = self.envs.get(cfg["env"])
        if not python_exe or not os.path.exists(python_exe):
            print(f"❌ 环境路径错误: {python_exe or cfg['env']}")
            return None
        if not os.path.isdir(cfg["cwd"]):
            print(f"❌ 工具目录不存在: {cfg['cwd']}")
            return None

        output_path = self._make_output_path(input_path, tool_key)
        output_name = os.path.basename(output_path)
        cmd = [
            python_exe,
            cfg["script"],
            "-i", input_path,
            "-o", output_path,
            "-m", cfg["model"],
        ]
        # Make the child write UTF-8 so it matches how we decode. On Windows a
        # piped child otherwise uses the locale code page (e.g. GBK), and a
        # strict decode error would turn a successful run into an exception.
        child_env = dict(os.environ, PYTHONIOENCODING="utf-8")

        print(f"\n⚡ [{tool_key}] 执行中...")
        try:
            result = subprocess.run(
                cmd,
                cwd=cfg["cwd"],
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",
                env=child_env,
                timeout=timeout,
            )
        except Exception as e:  # tool boundary: report, signal failure via None
            print(f"❌ 异常: {e}")
            return None

        if result.returncode == 0 and os.path.exists(output_path):
            print(f"✅ 完成 -> {output_name}")
            return output_path
        print(f"❌ 失败:\n{result.stderr}")
        if result.stdout:
            print(f"--- STDOUT ---\n{result.stdout}")
        if result.returncode == 0:
            print(f"⚠️ 进程返回 0,但未找到输出文件: {output_path}")
        return None
if __name__ == "__main__":
toolkit = JarvisRestorationToolkit()
# 默认图片
current_img = r"G:\datasets\realblur_dataset_test\075_blur_1.png"
menu = {
"1": "DarkIR",
"2": "SwinIR",
"3": "PromptIR",
"4": "CodeFormer",
"5": "ZeroDCE", # 编号已前移
"6": "PowerPaint", # 编号已前移
"7": "Restormer" # 编号已前移
}
print("="*60)
print(" JarvisIR 工具链集成测试 (稳定版)")
print("="*60)
while True:
print(f"\n当前处理图片: {current_img}")
print("-" * 30)
for k, v in menu.items():
desc = toolkit.tools_config[v]['desc']
print(f"[{k}] {v.ljust(12)} : {desc}")
print("-" * 30)
user_input = input("请输入工具序列 (如 '1 2',输入 'test_all' 测试全部,'q' 退出): ").strip()
if user_input.lower() == 'q':
break
# 冒烟测试
if user_input.lower() in ['test_all', 'test']:
print("\n🚀 启动全工具冒烟测试...")
for k, tool_key in menu.items():
print(f"\n========== 测试 {k}. {tool_key} ==========")
toolkit._run_worker(tool_key, current_img)
print("\n✅ 全测试结束!(已移除不稳定工具)")
continue
steps = user_input.split()
valid_pipeline = []
for s in steps:
if s in menu:
valid_pipeline.append(menu[s])
else:
print(f"⚠️ 跳过无效输入: {s}")
if not valid_pipeline: continue
print(f"\n🚀 启动流水线: {' -> '.join(valid_pipeline)}")
temp_img = current_img
for i, tool_key in enumerate(valid_pipeline):
print(f"\n>>> Step {i+1}/{len(valid_pipeline)}: 调用 {tool_key}")
res = toolkit._run_worker(tool_key, temp_img)
if res:
temp_img = res
else:
print("🚨 流水线中断!")
break
print(f"\n🎁 最终结果已保存: {temp_img}")
if input("是否将此结果作为下一轮输入? (y/n): ").lower() == 'y':
current_img = temp_img