MogensR committed
Commit eee126e · verified · 1 Parent(s): 77bba4f

Create models.py

Files changed (1)
  1. models.py +212 -0
models.py ADDED
@@ -0,0 +1,212 @@
+ #!/usr/bin/env python3
+ """
+ Model Loading and Memory Management
+ Handles lazy loading of SAM2 and MatAnyone models with caching
+ """
+
+ import os
+ import gc
+ import logging
+ import streamlit as st
+ import torch
+ import psutil
+
+ logger = logging.getLogger(__name__)
+
+ # Context manager for CUDA memory cleanup
+ from contextlib import contextmanager
+
+ @contextmanager
+ def torch_memory_manager():
+     """Context manager for CUDA memory cleanup."""
+     try:
+         yield
+     finally:
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         gc.collect()
+
+ def get_memory_usage():
+     """Get current memory usage statistics."""
+     memory_info = {}
+
+     # GPU memory if available
+     if torch.cuda.is_available():
+         memory_info['gpu_allocated'] = torch.cuda.memory_allocated() / 1e9
+         memory_info['gpu_reserved'] = torch.cuda.memory_reserved() / 1e9
+         memory_info['gpu_free'] = (torch.cuda.get_device_properties(0).total_memory -
+                                    torch.cuda.memory_allocated()) / 1e9
+
+     # RAM memory
+     memory_info['ram_used'] = psutil.virtual_memory().used / 1e9
+     memory_info['ram_available'] = psutil.virtual_memory().available / 1e9
+
+     return memory_info
+
+ def clear_model_cache():
+     """Clear all cached models and free memory."""
+     if hasattr(st, 'cache_resource'):
+         st.cache_resource.clear()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+     gc.collect()
+     logger.info("Model cache cleared")
+
+ # ============================================================================
+ # SAM2 Model Loading
+ # ============================================================================
+
+ @st.cache_resource(show_spinner=False)
+ def load_sam2_predictor():
+     """
+     Lazy load SAM2 image predictor with fallback strategies.
+     Returns None if loading fails.
+     """
+     try:
+         logger.info("Loading SAM2 image predictor...")
+         from sam2.build_sam import build_sam2
+         from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+         # Try local checkpoints first
+         checkpoint_path = "/home/user/app/checkpoints/sam2.1_hiera_large.pt"
+         model_cfg = "/home/user/app/configs/sam2.1/sam2.1_hiera_l.yaml"
+
+         if not os.path.exists(checkpoint_path) or not os.path.exists(model_cfg):
+             logger.warning("Local checkpoints not found, using Hugging Face...")
+             predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
+         else:
+             # Check available GPU memory
+             memory_info = get_memory_usage()
+             gpu_free = memory_info.get('gpu_free', 0)
+
+             if gpu_free < 4.0:
+                 logger.warning(f"Limited GPU memory ({gpu_free:.1f}GB), using smaller SAM2 model...")
+                 try:
+                     predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
+                 except Exception:
+                     predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
+             else:
+                 # Use local large model
+                 predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint_path))
+
+         logger.info("✅ SAM2 image predictor loaded successfully!")
+         return predictor
+
+     except Exception as e:
+         logger.error(f"Failed to load SAM2 predictor: {e}")
+         return None
+
+ # ============================================================================
+ # MatAnyone Model Loading
+ # ============================================================================
+
+ @st.cache_resource(show_spinner=False)
+ def load_matanyone_processor():
+     """
+     Lazy load MatAnyone processor.
+     Returns None if loading fails.
+     """
+     try:
+         logger.info("Loading MatAnyone processor...")
+         from matanyone import InferenceCore
+
+         processor = InferenceCore("PeiqingYang/MatAnyone")
+
+         logger.info("✅ MatAnyone processor loaded successfully!")
+         return processor
+
+     except Exception as e:
+         logger.error(f"Failed to load MatAnyone: {e}")
+         return None
+
+ # ============================================================================
+ # Model Health Check
+ # ============================================================================
+
+ def test_models():
+     """
+     Test if both models can load successfully.
+     Returns dict with test results.
+     """
+     results = {
+         'sam2': {'loaded': False, 'error': None},
+         'matanyone': {'loaded': False, 'error': None}
+     }
+
+     # Test SAM2
+     try:
+         sam2_predictor = load_sam2_predictor()
+         if sam2_predictor is not None:
+             results['sam2']['loaded'] = True
+         else:
+             results['sam2']['error'] = "Predictor returned None"
+     except Exception as e:
+         results['sam2']['error'] = str(e)
+
+     # Test MatAnyone
+     try:
+         matanyone_processor = load_matanyone_processor()
+         if matanyone_processor is not None:
+             results['matanyone']['loaded'] = True
+         else:
+             results['matanyone']['error'] = "Processor returned None"
+     except Exception as e:
+         results['matanyone']['error'] = str(e)
+
+     return results
+
+ # ============================================================================
+ # Memory Monitoring
+ # ============================================================================
+
+ def log_memory_usage(stage=""):
+     """Log current memory usage with optional stage label."""
+     memory_info = get_memory_usage()
+
+     log_msg = "Memory usage"
+     if stage:
+         log_msg += f" ({stage})"
+     log_msg += ":"
+
+     if 'gpu_allocated' in memory_info:
+         log_msg += f" GPU {memory_info['gpu_allocated']:.1f}GB allocated, {memory_info['gpu_free']:.1f}GB free"
+
+     log_msg += f" | RAM {memory_info['ram_used']:.1f}GB used"
+
+     logger.info(log_msg)
+     return memory_info
+
+ def check_memory_available(required_gb=2.0):
+     """
+     Check if enough GPU memory is available.
+     Returns (bool, float) - (is_available, free_gb)
+     """
+     if not torch.cuda.is_available():
+         return False, 0.0
+
+     memory_info = get_memory_usage()
+     free_gb = memory_info.get('gpu_free', 0)
+
+     return free_gb >= required_gb, free_gb
+
+ def free_memory_aggressive():
+     """Aggressively free GPU and system memory."""
+     logger.info("Performing aggressive memory cleanup...")
+
+     # Clear model cache
+     clear_model_cache()
+
+     # CUDA cleanup
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+         try:
+             torch.cuda.ipc_collect()
+         except Exception:
+             pass
+
+     # System cleanup
+     gc.collect()
+
+     logger.info("Memory cleanup complete")
+     log_memory_usage("after cleanup")
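
For reviewers: a minimal sketch of how these helpers might be wired together from the app side. The app-level code below is illustrative and not part of this commit; it assumes models.py is importable and that the 2.0 GB threshold is an acceptable floor for this Space.

    # Illustrative usage only — not part of this commit.
    import streamlit as st

    from models import (
        test_models,
        check_memory_available,
        free_memory_aggressive,
        load_sam2_predictor,
        load_matanyone_processor,
        log_memory_usage,
    )

    # Fail fast if either model cannot load.
    health = test_models()
    if not (health['sam2']['loaded'] and health['matanyone']['loaded']):
        st.error(f"Model health check failed: {health}")
        st.stop()

    # Reclaim memory if less than ~2 GB of GPU memory is free (assumed threshold).
    ok, free_gb = check_memory_available(required_gb=2.0)
    if not ok:
        free_memory_aggressive()

    log_memory_usage("app startup")

    # Cached loaders: st.cache_resource returns the same instances on every rerun.
    sam2 = load_sam2_predictor()
    matanyone = load_matanyone_processor()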