Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,8 +14,7 @@ import cv2
|
|
| 14 |
from typing import List, Tuple, Optional
|
| 15 |
import sys
|
| 16 |
|
| 17 |
-
|
| 18 |
-
# ------------------ 修复SAM2模块路径问题 ------------------
|
| 19 |
def add_sam2_to_path():
|
| 20 |
"""将SAM2安装目录添加到Python路径"""
|
| 21 |
sam2_dir = os.path.abspath("third_party/sam2")
|
|
@@ -23,7 +22,6 @@ def add_sam2_to_path():
|
|
| 23 |
sys.path.insert(0, sam2_dir)
|
| 24 |
return sam2_dir
|
| 25 |
|
| 26 |
-
# ------------------ 安装SAM2依赖 ------------------
|
| 27 |
def install_sam2():
|
| 28 |
"""检查并安装SAM2及其依赖"""
|
| 29 |
sam2_dir = "third_party/sam2"
|
|
@@ -31,7 +29,6 @@ def install_sam2():
|
|
| 31 |
print("Installing SAM2...")
|
| 32 |
os.makedirs("third_party", exist_ok=True)
|
| 33 |
|
| 34 |
-
# 克隆仓库(添加--recursive以防有子模块)
|
| 35 |
subprocess.run([
|
| 36 |
"git", "clone",
|
| 37 |
"--recursive",
|
|
@@ -39,15 +36,10 @@ def install_sam2():
|
|
| 39 |
sam2_dir
|
| 40 |
], check=True)
|
| 41 |
|
| 42 |
-
# 切换到sam2目录安装
|
| 43 |
original_dir = os.getcwd()
|
| 44 |
try:
|
| 45 |
os.chdir(sam2_dir)
|
| 46 |
-
|
| 47 |
-
# 先安装核心依赖
|
| 48 |
-
# subprocess.run(["pip", "install", "-r", "requirements.txt"], check=True)
|
| 49 |
-
|
| 50 |
-
# 以可编辑模式安装SAM2
|
| 51 |
subprocess.run(["pip", "install", "-e", "."], check=True)
|
| 52 |
|
| 53 |
|
|
@@ -62,27 +54,21 @@ def install_sam2():
|
|
| 62 |
print("SAM2 already exists, skipping installation.")
|
| 63 |
|
| 64 |
|
| 65 |
-
# 1. 安装SAM2
|
| 66 |
install_sam2()
|
| 67 |
|
| 68 |
-
# 2. 确保路径正确
|
| 69 |
sam2_dir = add_sam2_to_path()
|
| 70 |
|
| 71 |
-
# 3. 现在可以安全导入SAM2模块
|
| 72 |
from sam2.build_sam import build_sam2
|
| 73 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 74 |
print("🎉 SAM2 modules imported successfully!")
|
| 75 |
|
| 76 |
-
|
| 77 |
-
# 使用相对路径
|
| 78 |
MODEL_PATH = "geshang/Seg-R1-COD"
|
| 79 |
SAM_CHECKPOINT = "sam2_weights/sam2.1_hiera_large.pt"
|
| 80 |
|
| 81 |
-
# 自动检测设备
|
| 82 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 83 |
RESIZE_SIZE = (768, 768)
|
| 84 |
|
| 85 |
-
# 加载Qwen模型
|
| 86 |
try:
|
| 87 |
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 88 |
MODEL_PATH,
|
|
@@ -93,7 +79,6 @@ try:
|
|
| 93 |
print(f"Qwen model loaded on {DEVICE}")
|
| 94 |
except Exception as e:
|
| 95 |
print(f"Error loading Qwen model: {e}")
|
| 96 |
-
# 创建虚拟模型以便继续运行
|
| 97 |
model = None
|
| 98 |
processor = None
|
| 99 |
|
|
@@ -145,10 +130,8 @@ class CustomSAMWrapper:
|
|
| 145 |
print(f"SAM prediction error: {e}")
|
| 146 |
return np.zeros((image.height, image.width), dtype=bool), 0.0
|
| 147 |
|
| 148 |
-
# 初始化SAM包装器
|
| 149 |
sam_wrapper = CustomSAMWrapper(SAM_CHECKPOINT, device=DEVICE)
|
| 150 |
|
| 151 |
-
# ------------------ 推理相关函数 ------------------
|
| 152 |
|
| 153 |
def parse_custom_format(content: str):
|
| 154 |
point_pattern = r"<points>\s*(\[\s*(?:\[\s*\d+\s*,\s*\d+\s*\]\s*,?\s*)+\])\s*</points>"
|
|
@@ -331,7 +314,6 @@ def run_pipeline(image: PILImage.Image, prompt: str):
|
|
| 331 |
print(f"Pipeline error: {e}")
|
| 332 |
return f"Error processing request: {str(e)}", None
|
| 333 |
|
| 334 |
-
# ------------------ 启动 Gradio ------------------
|
| 335 |
|
| 336 |
with gr.Blocks(title="Seg-R1") as demo:
|
| 337 |
gr.Markdown("# Seg-R1: Visual Segmentation Assistant")
|
|
|
|
| 14 |
from typing import List, Tuple, Optional
|
| 15 |
import sys
|
| 16 |
|
| 17 |
+
|
|
|
|
| 18 |
def add_sam2_to_path():
|
| 19 |
"""将SAM2安装目录添加到Python路径"""
|
| 20 |
sam2_dir = os.path.abspath("third_party/sam2")
|
|
|
|
| 22 |
sys.path.insert(0, sam2_dir)
|
| 23 |
return sam2_dir
|
| 24 |
|
|
|
|
| 25 |
def install_sam2():
|
| 26 |
"""检查并安装SAM2及其依赖"""
|
| 27 |
sam2_dir = "third_party/sam2"
|
|
|
|
| 29 |
print("Installing SAM2...")
|
| 30 |
os.makedirs("third_party", exist_ok=True)
|
| 31 |
|
|
|
|
| 32 |
subprocess.run([
|
| 33 |
"git", "clone",
|
| 34 |
"--recursive",
|
|
|
|
| 36 |
sam2_dir
|
| 37 |
], check=True)
|
| 38 |
|
|
|
|
| 39 |
original_dir = os.getcwd()
|
| 40 |
try:
|
| 41 |
os.chdir(sam2_dir)
|
| 42 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
subprocess.run(["pip", "install", "-e", "."], check=True)
|
| 44 |
|
| 45 |
|
|
|
|
| 54 |
print("SAM2 already exists, skipping installation.")
|
| 55 |
|
| 56 |
|
|
|
|
| 57 |
install_sam2()
|
| 58 |
|
|
|
|
| 59 |
sam2_dir = add_sam2_to_path()
|
| 60 |
|
|
|
|
| 61 |
from sam2.build_sam import build_sam2
|
| 62 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 63 |
print("🎉 SAM2 modules imported successfully!")
|
| 64 |
|
| 65 |
+
|
|
|
|
| 66 |
MODEL_PATH = "geshang/Seg-R1-COD"
|
| 67 |
SAM_CHECKPOINT = "sam2_weights/sam2.1_hiera_large.pt"
|
| 68 |
|
|
|
|
| 69 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 70 |
RESIZE_SIZE = (768, 768)
|
| 71 |
|
|
|
|
| 72 |
try:
|
| 73 |
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 74 |
MODEL_PATH,
|
|
|
|
| 79 |
print(f"Qwen model loaded on {DEVICE}")
|
| 80 |
except Exception as e:
|
| 81 |
print(f"Error loading Qwen model: {e}")
|
|
|
|
| 82 |
model = None
|
| 83 |
processor = None
|
| 84 |
|
|
|
|
| 130 |
print(f"SAM prediction error: {e}")
|
| 131 |
return np.zeros((image.height, image.width), dtype=bool), 0.0
|
| 132 |
|
|
|
|
| 133 |
sam_wrapper = CustomSAMWrapper(SAM_CHECKPOINT, device=DEVICE)
|
| 134 |
|
|
|
|
| 135 |
|
| 136 |
def parse_custom_format(content: str):
|
| 137 |
point_pattern = r"<points>\s*(\[\s*(?:\[\s*\d+\s*,\s*\d+\s*\]\s*,?\s*)+\])\s*</points>"
|
|
|
|
| 314 |
print(f"Pipeline error: {e}")
|
| 315 |
return f"Error processing request: {str(e)}", None
|
| 316 |
|
|
|
|
| 317 |
|
| 318 |
with gr.Blocks(title="Seg-R1") as demo:
|
| 319 |
gr.Markdown("# Seg-R1: Visual Segmentation Assistant")
|