Update models/loaders/matanyone_loader.py
Browse files
models/loaders/matanyone_loader.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
| 14 |
- New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
|
| 15 |
- New: Full default cfg from official config.json to fix 'mem_every' issues
|
| 16 |
- Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
|
|
|
|
| 17 |
"""
|
| 18 |
from __future__ import annotations
|
| 19 |
import os
|
|
@@ -444,14 +445,14 @@ def load(self) -> Optional[Any]:
|
|
| 444 |
"long_term": {
|
| 445 |
"buffer_tokens": 2000,
|
| 446 |
"count_usage": True,
|
| 447 |
-
"max_mem_frames":
|
| 448 |
"max_num_tokens": 10000,
|
| 449 |
"min_mem_frames": 5,
|
| 450 |
"num_prototypes": 128
|
| 451 |
},
|
| 452 |
"max_internal_size": -1,
|
| 453 |
-
"max_mem_frames":
|
| 454 |
-
"mem_every":
|
| 455 |
"model": {
|
| 456 |
"aux_loss": {
|
| 457 |
"query": {
|
|
@@ -540,12 +541,24 @@ def load(self) -> Optional[Any]:
|
|
| 540 |
'mem_every': -1,
|
| 541 |
'max_mem_frames': 0,
|
| 542 |
'use_long_term': False,
|
| 543 |
-
'long_term': {
|
| 544 |
-
'max_mem_frames': 0,
|
| 545 |
-
'min_mem_frames': 0,
|
| 546 |
-
},
|
| 547 |
}
|
| 548 |
cfg.update(overrides)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
# Convert to EasyDict for dot access
|
| 550 |
cfg = EasyDict(cfg)
|
| 551 |
# Inference core
|
|
|
|
| 14 |
- New: EasyDict polyfill and conversion for cfg to fix 'dict no attribute' errors
|
| 15 |
- New: Full default cfg from official config.json to fix 'mem_every' issues
|
| 16 |
- Update: Disable memory propagation by setting mem_every=-1, max_mem_frames=0 to fix dim mismatch in fusion
|
| 17 |
+
- Fix: Merge long_term overrides to preserve keys like count_usage
|
| 18 |
"""
|
| 19 |
from __future__ import annotations
|
| 20 |
import os
|
|
|
|
| 445 |
"long_term": {
|
| 446 |
"buffer_tokens": 2000,
|
| 447 |
"count_usage": True,
|
| 448 |
+
"max_mem_frames": 10,
|
| 449 |
"max_num_tokens": 10000,
|
| 450 |
"min_mem_frames": 5,
|
| 451 |
"num_prototypes": 128
|
| 452 |
},
|
| 453 |
"max_internal_size": -1,
|
| 454 |
+
"max_mem_frames": 5,
|
| 455 |
+
"mem_every": 5,
|
| 456 |
"model": {
|
| 457 |
"aux_loss": {
|
| 458 |
"query": {
|
|
|
|
| 541 |
'mem_every': -1,
|
| 542 |
'max_mem_frames': 0,
|
| 543 |
'use_long_term': False,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 544 |
}
|
| 545 |
cfg.update(overrides)
|
| 546 |
+
# Merge long_term overrides without overwriting whole dict
|
| 547 |
+
if 'long_term' in cfg:
|
| 548 |
+
cfg['long_term'].update({
|
| 549 |
+
'max_mem_frames': 0,
|
| 550 |
+
'min_mem_frames': 0,
|
| 551 |
+
'count_usage': False, # Disable since memory off
|
| 552 |
+
})
|
| 553 |
+
else:
|
| 554 |
+
cfg['long_term'] = {
|
| 555 |
+
'buffer_tokens': 2000,
|
| 556 |
+
'count_usage': False,
|
| 557 |
+
'max_mem_frames': 0,
|
| 558 |
+
'max_num_tokens': 10000,
|
| 559 |
+
'min_mem_frames': 0,
|
| 560 |
+
'num_prototypes': 128
|
| 561 |
+
}
|
| 562 |
# Convert to EasyDict for dot access
|
| 563 |
cfg = EasyDict(cfg)
|
| 564 |
# Inference core
|