MogensR commited on
Commit
3814cb0
·
verified ·
1 Parent(s): c91f93f

Update models/model_loaders.py

Browse files
Files changed (1) hide show
  1. models/model_loaders.py +16 -24
models/model_loaders.py CHANGED
@@ -11,21 +11,18 @@
11
  import torch
12
  import psutil
13
  from contextlib import contextmanager
14
-
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
-
18
  @contextmanager
19
  def torch_memory_manager():
20
  try:
21
- logger.info("[torch_memory_manager] Enter") # [LOG+SAFETY PATCH]
22
  yield
23
  finally:
24
  if torch.cuda.is_available():
25
  torch.cuda.empty_cache()
26
  gc.collect()
27
- logger.info("[torch_memory_manager] Exit, cleaned up") # [LOG+SAFETY PATCH]
28
-
29
  def get_memory_usage():
30
  memory_info = {}
31
  if torch.cuda.is_available():
@@ -35,22 +32,20 @@ def get_memory_usage():
35
  torch.cuda.memory_allocated()) / 1e9
36
  memory_info['ram_used'] = psutil.virtual_memory().used / 1e9
37
  memory_info['ram_available'] = psutil.virtual_memory().available / 1e9
38
- logger.info(f"[get_memory_usage] {memory_info}") # [LOG+SAFETY PATCH]
39
  return memory_info
40
-
41
  def clear_model_cache():
42
- logger.info("[clear_model_cache] Clearing all model caches...") # [LOG+SAFETY PATCH]
43
  if hasattr(st, 'cache_resource'):
44
  st.cache_resource.clear()
45
  if torch.cuda.is_available():
46
  torch.cuda.empty_cache()
47
  gc.collect()
48
- logger.info("[clear_model_cache] Model cache cleared") # [LOG+SAFETY PATCH]
49
-
50
  @st.cache_resource(show_spinner=False)
51
  def load_sam2_predictor():
52
  try:
53
- logger.info("[load_sam2_predictor] Loading SAM2 image predictor...") # [LOG+SAFETY PATCH]
54
  from sam2.build_sam import build_sam2
55
  from sam2.sam2_image_predictor import SAM2ImagePredictor
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -94,19 +89,21 @@ def load_sam2_predictor():
94
  import traceback
95
  traceback.print_exc()
96
  return None, None
97
-
98
  def load_sam2():
99
  predictor, device = load_sam2_predictor()
100
  return predictor
101
-
102
  @st.cache_resource(show_spinner=False)
103
  def load_matanyone_processor():
104
  try:
105
- logger.info("[load_matanyone_processor] Loading MatAnyone processor...") # [LOG+SAFETY PATCH]
106
  from matanyone import InferenceCore
107
  device = "cuda" if torch.cuda.is_available() else "cpu"
108
  logger.info(f"[load_matanyone_processor] MatAnyone using device: {device}")
109
- processor = InferenceCore("PeiqingYang/MatAnyone", device=device)
 
 
 
 
110
  if hasattr(processor, 'model'):
111
  processor.model.to(device)
112
  processor.model.eval()
@@ -121,11 +118,9 @@ def load_matanyone_processor():
121
  import traceback
122
  traceback.print_exc()
123
  return None, None
124
-
125
  def load_matanyone():
126
  processor, device = load_matanyone_processor()
127
  return processor
128
-
129
  def test_models():
130
  results = {
131
  'sam2': {'loaded': False, 'error': None, 'device': None},
@@ -151,9 +146,8 @@ def test_models():
151
  except Exception as e:
152
  results['matanyone']['error'] = str(e)
153
  logger.error(f"[test_models] MatAnyone error: {e}", exc_info=True)
154
- logger.info(f"[test_models] Results: {results}") # [LOG+SAFETY PATCH]
155
  return results
156
-
157
  def log_memory_usage(stage=""):
158
  memory_info = get_memory_usage()
159
  log_msg = f"Memory usage"
@@ -166,17 +160,15 @@ def log_memory_usage(stage=""):
166
  print(log_msg, flush=True)
167
  logger.info(log_msg)
168
  return memory_info
169
-
170
  def check_memory_available(required_gb=2.0):
171
  if not torch.cuda.is_available():
172
  return False, 0.0
173
  memory_info = get_memory_usage()
174
  free_gb = memory_info.get('gpu_free', 0)
175
- logger.info(f"[check_memory_available] free_gb={free_gb}, required={required_gb}") # [LOG+SAFETY PATCH]
176
  return free_gb >= required_gb, free_gb
177
-
178
  def free_memory_aggressive():
179
- logger.info("[free_memory_aggressive] Performing aggressive memory cleanup...") # [LOG+SAFETY PATCH]
180
  print("Performing aggressive memory cleanup...", flush=True)
181
  clear_model_cache()
182
  if torch.cuda.is_available():
@@ -189,4 +181,4 @@ def free_memory_aggressive():
189
  gc.collect()
190
  print("Memory cleanup complete", flush=True)
191
  logger.info("Memory cleanup complete")
192
- log_memory_usage("after cleanup")
 
11
  import torch
12
  import psutil
13
  from contextlib import contextmanager
 
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
 
16
  @contextmanager
17
  def torch_memory_manager():
18
  try:
19
+ logger.info("[torch_memory_manager] Enter")
20
  yield
21
  finally:
22
  if torch.cuda.is_available():
23
  torch.cuda.empty_cache()
24
  gc.collect()
25
+ logger.info("[torch_memory_manager] Exit, cleaned up")
 
26
  def get_memory_usage():
27
  memory_info = {}
28
  if torch.cuda.is_available():
 
32
  torch.cuda.memory_allocated()) / 1e9
33
  memory_info['ram_used'] = psutil.virtual_memory().used / 1e9
34
  memory_info['ram_available'] = psutil.virtual_memory().available / 1e9
35
+ logger.info(f"[get_memory_usage] {memory_info}")
36
  return memory_info
 
37
  def clear_model_cache():
38
+ logger.info("[clear_model_cache] Clearing all model caches...")
39
  if hasattr(st, 'cache_resource'):
40
  st.cache_resource.clear()
41
  if torch.cuda.is_available():
42
  torch.cuda.empty_cache()
43
  gc.collect()
44
+ logger.info("[clear_model_cache] Model cache cleared")
 
45
  @st.cache_resource(show_spinner=False)
46
  def load_sam2_predictor():
47
  try:
48
+ logger.info("[load_sam2_predictor] Loading SAM2 image predictor...")
49
  from sam2.build_sam import build_sam2
50
  from sam2.sam2_image_predictor import SAM2ImagePredictor
51
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
89
  import traceback
90
  traceback.print_exc()
91
  return None, None
 
92
  def load_sam2():
93
  predictor, device = load_sam2_predictor()
94
  return predictor
 
95
  @st.cache_resource(show_spinner=False)
96
  def load_matanyone_processor():
97
  try:
98
+ logger.info("[load_matanyone_processor] Loading MatAnyone processor...")
99
  from matanyone import InferenceCore
100
  device = "cuda" if torch.cuda.is_available() else "cpu"
101
  logger.info(f"[load_matanyone_processor] MatAnyone using device: {device}")
102
+ try:
103
+ processor = InferenceCore("PeiqingYang/MatAnyone", device=device)
104
+ except Exception as e:
105
+ logger.warning(f"[load_matanyone_processor] Path warning caught: {e}")
106
+ processor = InferenceCore("PeiqingYang/MatAnyone", device=device) # Retry
107
  if hasattr(processor, 'model'):
108
  processor.model.to(device)
109
  processor.model.eval()
 
118
  import traceback
119
  traceback.print_exc()
120
  return None, None
 
121
  def load_matanyone():
122
  processor, device = load_matanyone_processor()
123
  return processor
 
124
  def test_models():
125
  results = {
126
  'sam2': {'loaded': False, 'error': None, 'device': None},
 
146
  except Exception as e:
147
  results['matanyone']['error'] = str(e)
148
  logger.error(f"[test_models] MatAnyone error: {e}", exc_info=True)
149
+ logger.info(f"[test_models] Results: {results}")
150
  return results
 
151
  def log_memory_usage(stage=""):
152
  memory_info = get_memory_usage()
153
  log_msg = f"Memory usage"
 
160
  print(log_msg, flush=True)
161
  logger.info(log_msg)
162
  return memory_info
 
163
  def check_memory_available(required_gb=2.0):
164
  if not torch.cuda.is_available():
165
  return False, 0.0
166
  memory_info = get_memory_usage()
167
  free_gb = memory_info.get('gpu_free', 0)
168
+ logger.info(f"[check_memory_available] free_gb={free_gb}, required={required_gb}")
169
  return free_gb >= required_gb, free_gb
 
170
  def free_memory_aggressive():
171
+ logger.info("[free_memory_aggressive] Performing aggressive memory cleanup...")
172
  print("Performing aggressive memory cleanup...", flush=True)
173
  clear_model_cache()
174
  if torch.cuda.is_available():
 
181
  gc.collect()
182
  print("Memory cleanup complete", flush=True)
183
  logger.info("Memory cleanup complete")
184
+ log_memory_usage("after cleanup")