Update models/loaders/matanyone_loader.py
Browse files
models/loaders/matanyone_loader.py
CHANGED
|
@@ -12,7 +12,8 @@
|
|
| 12 |
- Added: Pad to multiple of 16 to avoid transformer patch issues
|
| 13 |
- Added: Prefer fp16 over bf16 for Tesla T4 compatibility
|
| 14 |
- New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
|
| 15 |
-
- New: Full default cfg from official config.json to
|
|
|
|
| 16 |
"""
|
| 17 |
from __future__ import annotations
|
| 18 |
import os
|
|
@@ -443,14 +444,14 @@ def load(self) -> Optional[Any]:
|
|
| 443 |
"long_term": {
|
| 444 |
"buffer_tokens": 2000,
|
| 445 |
"count_usage": True,
|
| 446 |
-
"max_mem_frames":
|
| 447 |
"max_num_tokens": 10000,
|
| 448 |
"min_mem_frames": 5,
|
| 449 |
"num_prototypes": 128
|
| 450 |
},
|
| 451 |
"max_internal_size": -1,
|
| 452 |
-
"max_mem_frames":
|
| 453 |
-
"mem_every":
|
| 454 |
"model": {
|
| 455 |
"aux_loss": {
|
| 456 |
"query": {
|
|
@@ -532,10 +533,17 @@ def load(self) -> Optional[Any]:
|
|
| 532 |
cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
|
| 533 |
if isinstance(cfg, dict):
|
| 534 |
cfg = dict(cfg) # Copy to avoid modifying model.cfg
|
| 535 |
-
# Override specific values
|
| 536 |
overrides = {
|
| 537 |
'chunk_size': 1,
|
| 538 |
'flip_aug': False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
}
|
| 540 |
cfg.update(overrides)
|
| 541 |
# Convert to EasyDict for dot access
|
|
|
|
| 12 |
- Added: Pad to multiple of 16 to avoid transformer patch issues
|
| 13 |
- Added: Prefer fp16 over bf16 for Tesla T4 compatibility
|
| 14 |
- New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
|
| 15 |
+
- New: Full default cfg from official config.json to fix 'mem_every' issues
|
| 16 |
+
- Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
|
| 17 |
"""
|
| 18 |
from __future__ import annotations
|
| 19 |
import os
|
|
|
|
| 444 |
"long_term": {
|
| 445 |
"buffer_tokens": 2000,
|
| 446 |
"count_usage": True,
|
| 447 |
+
"max_mem_frames": 0, # Disable long-term memory
|
| 448 |
"max_num_tokens": 10000,
|
| 449 |
"min_mem_frames": 5,
|
| 450 |
"num_prototypes": 128
|
| 451 |
},
|
| 452 |
"max_internal_size": -1,
|
| 453 |
+
"max_mem_frames": 0, # Disable short-term memory
|
| 454 |
+
"mem_every": -1, # Disable memory updates
|
| 455 |
"model": {
|
| 456 |
"aux_loss": {
|
| 457 |
"query": {
|
|
|
|
| 533 |
cfg = getattr(self.model, "cfg", default_cfg) or default_cfg
|
| 534 |
if isinstance(cfg, dict):
|
| 535 |
cfg = dict(cfg) # Copy to avoid modifying model.cfg
|
| 536 |
+
# Override specific values to disable memory and potential dim issues
|
| 537 |
overrides = {
|
| 538 |
'chunk_size': 1,
|
| 539 |
'flip_aug': False,
|
| 540 |
+
'mem_every': -1,
|
| 541 |
+
'max_mem_frames': 0,
|
| 542 |
+
'use_long_term': False,
|
| 543 |
+
'long_term': {
|
| 544 |
+
'max_mem_frames': 0,
|
| 545 |
+
'min_mem_frames': 0,
|
| 546 |
+
},
|
| 547 |
}
|
| 548 |
cfg.update(overrides)
|
| 549 |
# Convert to EasyDict for dot access
|