amewebstudio commited on
Commit
338eca0
·
verified ·
1 Parent(s): 67ae1c5

Upload configuration_nexus_worldmodel.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_nexus_worldmodel.py +70 -0
configuration_nexus_worldmodel.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NEXUS-WorldModel Configuration
3
+
4
+ Auto-generated configuration class for loading from HuggingFace Hub.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ from dataclasses import dataclass, field
10
+ from typing import List, Dict, Any, Optional
11
+
12
+ @dataclass
13
+ class WorldConfig:
14
+ width: int = 64
15
+ height: int = 64
16
+ channels: int = 3
17
+ gravity: float = 0.1
18
+ friction: float = 0.98
19
+ bounce: float = 0.8
20
+ max_velocity: float = 5.0
21
+ max_agents: int = 5
22
+ max_obstacles: int = 10
23
+ max_zones: int = 3
24
+ agent_radius: float = 2.0
25
+ dt: float = 1.0
26
+
27
+ @dataclass
28
+ class NexusWorldModelConfig:
29
+ """Configuration for NEXUS-WorldModel v2.0"""
30
+
31
+ model_type: str = "nexus-worldmodel"
32
+ version: str = "2.0"
33
+
34
+ d_model: int = 512
35
+ d_ff: int = 2048
36
+ n_layers: int = 8
37
+ n_heads: int = 8
38
+ dropout: float = 0.1
39
+
40
+ latent_dim: int = 256
41
+ latent_state_dim: int = 256
42
+ action_dim: int = 5
43
+
44
+ use_lpol: bool = True
45
+ use_gqa: bool = True
46
+ gqa_num_heads: int = 8
47
+ gqa_num_kv_groups: int = 2
48
+
49
+ neurogenesis_enabled: bool = True
50
+ dream_enabled: bool = True
51
+
52
+ world: WorldConfig = field(default_factory=WorldConfig)
53
+
54
+ def to_dict(self) -> Dict[str, Any]:
55
+ return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
56
+
57
+ @classmethod
58
+ def from_pretrained(cls, path_or_repo: str, **kwargs):
59
+ """Load config from HuggingFace Hub or local path"""
60
+ if os.path.isdir(path_or_repo):
61
+ config_file = os.path.join(path_or_repo, "config.json")
62
+ else:
63
+ from huggingface_hub import hf_hub_download
64
+ config_file = hf_hub_download(repo_id=path_or_repo, filename="config.json", **kwargs)
65
+
66
+ with open(config_file, "r") as f:
67
+ config_dict = json.load(f)
68
+
69
+ return cls(**{k: v for k, v in config_dict.items()
70
+ if k in cls.__dataclass_fields__ and k != "world"})