MediaStreamAI commited on
Commit
8c2d15a
Β·
verified Β·
1 Parent(s): d86a206

Upload inference.py (chunk 450 W2.7)

Browse files
Files changed (1) hide show
  1. inference.py +197 -0
inference.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MOTHER CORE V2 β€” chunk 450 (W2.7) β€” Reference Inference
4
+ ========================================================
5
+ Sovereign UK AI by MediaStream AI Limited (MSAI).
6
+
7
+ This script loads chunk 450 from HuggingFace and runs the LOCKED inference
8
+ rules used during training. Deviation from these rules produces incorrect
9
+ or degenerate output.
10
+
11
+ Usage:
12
+ python inference.py "What is the capital of Scotland?"
13
+ python inference.py # enters interactive mode
14
+
15
+ Requirements:
16
+ pip install torch safetensors sentencepiece huggingface_hub
17
+ """
18
+ from __future__ import annotations
19
+ import sys
20
+ import json
21
+ import torch
22
+ from pathlib import Path
23
+ from safetensors.torch import load_file
24
+ import sentencepiece as spm
25
+
26
+ # ════════════════════════════════════════════════════════════════════
27
+ # LOCKED INFERENCE RULES (DO NOT CHANGE)
28
+ # ════════════════════════════════════════════════════════════════════
29
+ BOS_ID = 1
30
+ EOS_ID = 2
31
+ PAD_ID = 0
32
+ PROMPT_FORMAT = "Question:\n\n{q}\n\nAnswer:"
33
+ REP_PEN = 1.3
34
+ NO_REPEAT_NGRAM = 4
35
+ MAX_NEW = 200
36
+ # Greedy argmax β€” no temperature, no sampling
37
+
38
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+ DTYPE = torch.bfloat16
40
+
41
+
42
+ def load_model_and_tokenizer(repo_dir: str):
43
+ """Load MOTHER CORE from a local directory (downloaded HF snapshot)."""
44
+ repo = Path(repo_dir)
45
+
46
+ # Load config
47
+ with open(repo / "config.json") as f:
48
+ cfg = json.load(f)
49
+ print(f"Loaded config: {cfg['n_layers']} layers, dim={cfg['dim']}, "
50
+ f"params~{cfg.get('_msai_total_params_b', '?')}B")
51
+
52
+ # Load tokenizer (SentencePiece)
53
+ tokenizer = spm.SentencePieceProcessor()
54
+ tokenizer.Load(str(repo / "tokenizer.model"))
55
+ print(f"Loaded tokenizer: vocab_size={tokenizer.vocab_size()}")
56
+
57
+ # Build model β€” requires mother_core package available
58
+ try:
59
+ sys.path.insert(0, str(Path.home() / "mother-core-reasoning"))
60
+ from mother_core.config import ModelConfig
61
+ from mother_core.model import MotherCoreModel
62
+ except ImportError:
63
+ print("ERROR: mother_core package not found.")
64
+ print("This script requires the mother_core source code to be available.")
65
+ print("Either clone the MSAI sovereign training repo, or copy "
66
+ "mother_core/ into your PYTHONPATH.")
67
+ sys.exit(1)
68
+
69
+ config = ModelConfig(
70
+ vocab_size=cfg["vocab_size"],
71
+ dim=cfg["dim"],
72
+ n_layers=cfg["n_layers"],
73
+ n_heads=cfg["n_heads"],
74
+ n_kv_heads=cfg["n_kv_heads"],
75
+ ff_mult=cfg["ff_mult"],
76
+ max_seq_len=cfg["max_seq_len"],
77
+ rope_theta=cfg["rope_theta"],
78
+ rms_norm_eps=cfg["rms_norm_eps"],
79
+ )
80
+ model = MotherCoreModel(config)
81
+
82
+ # Load sharded safetensors
83
+ index_path = repo / "model.safetensors.index.json"
84
+ if index_path.exists():
85
+ with open(index_path) as f:
86
+ index = json.load(f)
87
+ shard_files = sorted(set(index["weight_map"].values()))
88
+ print(f"Loading {len(shard_files)} shards...")
89
+ full_sd = {}
90
+ for sf in shard_files:
91
+ print(f" - {sf}")
92
+ full_sd.update(load_file(str(repo / sf)))
93
+ model.load_state_dict(full_sd, strict=False)
94
+ else:
95
+ # Single-file fallback
96
+ sd = load_file(str(repo / "model.safetensors"))
97
+ model.load_state_dict(sd, strict=False)
98
+
99
+ model = model.to(DTYPE).to(DEVICE).eval()
100
+ print(f"Model on {DEVICE} in {DTYPE}")
101
+ return model, tokenizer
102
+
103
+
104
+ @torch.no_grad()
105
+ def generate_greedy(model, tokenizer, question: str,
106
+ max_new: int = MAX_NEW,
107
+ rep_pen: float = REP_PEN,
108
+ no_repeat_ngram: int = NO_REPEAT_NGRAM) -> str:
109
+ """
110
+ LOCKED inference path. Greedy argmax with n-gram blocking and
111
+ frequency-scaled repetition penalty.
112
+ """
113
+ prompt = PROMPT_FORMAT.format(q=question)
114
+ ids = [BOS_ID] + tokenizer.EncodeAsIds(prompt)
115
+ inp = torch.tensor([ids], device=DEVICE)
116
+ gen_out = []
117
+
118
+ for i in range(max_new):
119
+ x = inp if i == 0 else torch.tensor([[gen_out[-1]]], device=DEVICE)
120
+ out = model(x)
121
+ logits = out["logits"][:, -1, :].float()
122
+
123
+ # Block BOS in generated output, allow EOS only after at least 1 token
124
+ if len(gen_out) < 1:
125
+ logits[0, EOS_ID] = -1e9
126
+ logits[0, BOS_ID] = -1e9
127
+
128
+ # Frequency-scaled repetition penalty (only tokens seen β‰₯ 2 times)
129
+ if len(gen_out) >= 3:
130
+ from collections import Counter
131
+ counts = Counter(gen_out)
132
+ for t, c in counts.items():
133
+ if c >= 2 and 0 <= t < logits.shape[-1]:
134
+ logits[0, t] /= (rep_pen ** (c - 1))
135
+
136
+ # n-gram blocking
137
+ if no_repeat_ngram > 0 and len(gen_out) >= no_repeat_ngram:
138
+ ngram = tuple(gen_out[-(no_repeat_ngram - 1):]) if no_repeat_ngram > 1 else ()
139
+ banned = set()
140
+ for j in range(len(gen_out) - no_repeat_ngram + 1):
141
+ if tuple(gen_out[j:j + no_repeat_ngram - 1]) == ngram:
142
+ banned.add(gen_out[j + no_repeat_ngram - 1])
143
+ for t in banned:
144
+ if 0 <= t < logits.shape[-1]:
145
+ logits[0, t] = -1e9
146
+
147
+ # Greedy argmax (no temperature, no sampling)
148
+ nxt = logits.argmax(-1).item()
149
+
150
+ if nxt == EOS_ID:
151
+ break
152
+ gen_out.append(nxt)
153
+
154
+ # Cycle-break: 4 identical tokens in a row
155
+ if len(gen_out) >= 4 and len(set(gen_out[-4:])) == 1:
156
+ break
157
+
158
+ return tokenizer.DecodeIds(gen_out).strip()
159
+
160
+
161
+ def main():
162
+ # Download from HF if needed
163
+ try:
164
+ from huggingface_hub import snapshot_download
165
+ except ImportError:
166
+ print("ERROR: pip install huggingface_hub")
167
+ sys.exit(1)
168
+
169
+ print("Downloading MediaStreamAI/MOTHER_CORE_V2 ...")
170
+ repo_dir = snapshot_download(repo_id="MediaStreamAI/MOTHER_CORE_V2")
171
+ print(f"Local snapshot: {repo_dir}")
172
+ model, tokenizer = load_model_and_tokenizer(repo_dir)
173
+
174
+ if len(sys.argv) > 1:
175
+ question = " ".join(sys.argv[1:])
176
+ print(f"\nQ: {question}")
177
+ ans = generate_greedy(model, tokenizer, question)
178
+ print(f"A: {ans}")
179
+ return
180
+
181
+ print("\nInteractive mode. Type 'quit' to exit.\n")
182
+ while True:
183
+ try:
184
+ q = input("Q: ").strip()
185
+ except (EOFError, KeyboardInterrupt):
186
+ print()
187
+ break
188
+ if q.lower() in ("quit", "exit"):
189
+ break
190
+ if not q:
191
+ continue
192
+ ans = generate_greedy(model, tokenizer, q)
193
+ print(f"A: {ans}\n")
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()