geshang commited on
Commit
24a032c
·
verified ·
1 Parent(s): cb86726

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -21
app.py CHANGED
@@ -14,8 +14,7 @@ import cv2
14
  from typing import List, Tuple, Optional
15
  import sys
16
 
17
- # ------------------ 安装SAM2依赖 ------------------
18
- # ------------------ 修复SAM2模块路径问题 ------------------
19
  def add_sam2_to_path():
20
  """将SAM2安装目录添加到Python路径"""
21
  sam2_dir = os.path.abspath("third_party/sam2")
@@ -23,7 +22,6 @@ def add_sam2_to_path():
23
  sys.path.insert(0, sam2_dir)
24
  return sam2_dir
25
 
26
- # ------------------ 安装SAM2依赖 ------------------
27
  def install_sam2():
28
  """检查并安装SAM2及其依赖"""
29
  sam2_dir = "third_party/sam2"
@@ -31,7 +29,6 @@ def install_sam2():
31
  print("Installing SAM2...")
32
  os.makedirs("third_party", exist_ok=True)
33
 
34
- # 克隆仓库(添加--recursive以防有子模块)
35
  subprocess.run([
36
  "git", "clone",
37
  "--recursive",
@@ -39,15 +36,10 @@ def install_sam2():
39
  sam2_dir
40
  ], check=True)
41
 
42
- # 切换到sam2目录安装
43
  original_dir = os.getcwd()
44
  try:
45
  os.chdir(sam2_dir)
46
-
47
- # 先安装核心依赖
48
- # subprocess.run(["pip", "install", "-r", "requirements.txt"], check=True)
49
-
50
- # 以可编辑模式安装SAM2
51
  subprocess.run(["pip", "install", "-e", "."], check=True)
52
 
53
 
@@ -62,27 +54,21 @@ def install_sam2():
62
  print("SAM2 already exists, skipping installation.")
63
 
64
 
65
- # 1. 安装SAM2
66
  install_sam2()
67
 
68
- # 2. 确保路径正确
69
  sam2_dir = add_sam2_to_path()
70
 
71
- # 3. 现在可以安全导入SAM2模块
72
  from sam2.build_sam import build_sam2
73
  from sam2.sam2_image_predictor import SAM2ImagePredictor
74
  print("🎉 SAM2 modules imported successfully!")
75
 
76
- # ------------------ 初始化模型 ------------------
77
- # 使用相对路径
78
  MODEL_PATH = "geshang/Seg-R1-COD"
79
  SAM_CHECKPOINT = "sam2_weights/sam2.1_hiera_large.pt"
80
 
81
- # 自动检测设备
82
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
83
  RESIZE_SIZE = (768, 768)
84
 
85
- # 加载Qwen模型
86
  try:
87
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
88
  MODEL_PATH,
@@ -93,7 +79,6 @@ try:
93
  print(f"Qwen model loaded on {DEVICE}")
94
  except Exception as e:
95
  print(f"Error loading Qwen model: {e}")
96
- # 创建虚拟模型以便继续运行
97
  model = None
98
  processor = None
99
 
@@ -145,10 +130,8 @@ class CustomSAMWrapper:
145
  print(f"SAM prediction error: {e}")
146
  return np.zeros((image.height, image.width), dtype=bool), 0.0
147
 
148
- # 初始化SAM包装器
149
  sam_wrapper = CustomSAMWrapper(SAM_CHECKPOINT, device=DEVICE)
150
 
151
- # ------------------ 推理相关函数 ------------------
152
 
153
  def parse_custom_format(content: str):
154
  point_pattern = r"<points>\s*(\[\s*(?:\[\s*\d+\s*,\s*\d+\s*\]\s*,?\s*)+\])\s*</points>"
@@ -331,7 +314,6 @@ def run_pipeline(image: PILImage.Image, prompt: str):
331
  print(f"Pipeline error: {e}")
332
  return f"Error processing request: {str(e)}", None
333
 
334
- # ------------------ 启动 Gradio ------------------
335
 
336
  with gr.Blocks(title="Seg-R1") as demo:
337
  gr.Markdown("# Seg-R1: Visual Segmentation Assistant")
 
14
  from typing import List, Tuple, Optional
15
  import sys
16
 
17
+
 
18
  def add_sam2_to_path():
19
  """将SAM2安装目录添加到Python路径"""
20
  sam2_dir = os.path.abspath("third_party/sam2")
 
22
  sys.path.insert(0, sam2_dir)
23
  return sam2_dir
24
 
 
25
  def install_sam2():
26
  """检查并安装SAM2及其依赖"""
27
  sam2_dir = "third_party/sam2"
 
29
  print("Installing SAM2...")
30
  os.makedirs("third_party", exist_ok=True)
31
 
 
32
  subprocess.run([
33
  "git", "clone",
34
  "--recursive",
 
36
  sam2_dir
37
  ], check=True)
38
 
 
39
  original_dir = os.getcwd()
40
  try:
41
  os.chdir(sam2_dir)
42
+
 
 
 
 
43
  subprocess.run(["pip", "install", "-e", "."], check=True)
44
 
45
 
 
54
  print("SAM2 already exists, skipping installation.")
55
 
56
 
 
57
  install_sam2()
58
 
 
59
  sam2_dir = add_sam2_to_path()
60
 
 
61
  from sam2.build_sam import build_sam2
62
  from sam2.sam2_image_predictor import SAM2ImagePredictor
63
  print("🎉 SAM2 modules imported successfully!")
64
 
65
+
 
66
  MODEL_PATH = "geshang/Seg-R1-COD"
67
  SAM_CHECKPOINT = "sam2_weights/sam2.1_hiera_large.pt"
68
 
 
69
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
70
  RESIZE_SIZE = (768, 768)
71
 
 
72
  try:
73
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
74
  MODEL_PATH,
 
79
  print(f"Qwen model loaded on {DEVICE}")
80
  except Exception as e:
81
  print(f"Error loading Qwen model: {e}")
 
82
  model = None
83
  processor = None
84
 
 
130
  print(f"SAM prediction error: {e}")
131
  return np.zeros((image.height, image.width), dtype=bool), 0.0
132
 
 
133
  sam_wrapper = CustomSAMWrapper(SAM_CHECKPOINT, device=DEVICE)
134
 
 
135
 
136
  def parse_custom_format(content: str):
137
  point_pattern = r"<points>\s*(\[\s*(?:\[\s*\d+\s*,\s*\d+\s*\]\s*,?\s*)+\])\s*</points>"
 
314
  print(f"Pipeline error: {e}")
315
  return f"Error processing request: {str(e)}", None
316
 
 
317
 
318
  with gr.Blocks(title="Seg-R1") as demo:
319
  gr.Markdown("# Seg-R1: Visual Segmentation Assistant")