File size: 6,879 Bytes
8c2d15a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#!/usr/bin/env python3
"""
MOTHER CORE V2 — chunk 450 (W2.7) — Reference Inference
========================================================
Sovereign UK AI by MediaStream AI Limited (MSAI).

This script loads chunk 450 from HuggingFace and runs the LOCKED inference
rules used during training. Deviation from these rules produces incorrect
or degenerate output.

Usage:
    python inference.py "What is the capital of Scotland?"
    python inference.py  # enters interactive mode

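Requirements:
    pip install torch safetensors sentencepiece huggingface_hub
    The mother_core package (not covered by the pip line above) must also be
    importable, either from ~/mother-core-reasoning or via PYTHONPATH.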
"""
from __future__ import annotations

import json
import sys
from collections import Counter
from pathlib import Path

import sentencepiece as spm
import torch
from safetensors.torch import load_file

# ════════════════════════════════════════════════════════════════════
# LOCKED INFERENCE RULES (DO NOT CHANGE)
# ════════════════════════════════════════════════════════════════════
# Special token ids used when building the prompt and filtering output.
PAD_ID = 0
BOS_ID = 1
EOS_ID = 2

# Prompt template; "{q}" is substituted with the user's question.
PROMPT_FORMAT = "Question:\n\n{q}\n\nAnswer:"

# Decoding controls (defaults of generate_greedy). Decoding is pure greedy
# argmax — no temperature scaling and no sampling.
REP_PEN = 1.3          # base of the frequency-scaled repetition penalty
NO_REPEAT_NGRAM = 4    # n-gram size that may not repeat in generated output
MAX_NEW = 200          # upper bound on newly generated tokens

# Run on CUDA when it is available, otherwise on CPU; weights in bfloat16.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DTYPE = torch.bfloat16


def _import_model_classes():
    """Import and return ``(ModelConfig, MotherCoreModel)`` from mother_core.

    ``~/mother-core-reasoning`` is put at the front of ``sys.path`` (once) so a
    checkout in that location is found. Exits the process with status 1 if
    the package cannot be imported.
    """
    extra_path = str(Path.home() / "mother-core-reasoning")
    if extra_path not in sys.path:
        sys.path.insert(0, extra_path)
    try:
        from mother_core.config import ModelConfig
        from mother_core.model import MotherCoreModel
    except ImportError:
        print("ERROR: mother_core package not found.", file=sys.stderr)
        print("This script requires the mother_core source code to be available.",
              file=sys.stderr)
        print("Either clone the MSAI sovereign training repo, or copy "
              "mother_core/ into your PYTHONPATH.", file=sys.stderr)
        sys.exit(1)
    return ModelConfig, MotherCoreModel


def _load_weights(model, repo: Path) -> None:
    """Load safetensors weights from *repo* into *model* (non-strict).

    Uses ``model.safetensors.index.json`` and its shard files when the index
    exists, otherwise a single ``model.safetensors``. Loading stays
    non-strict, but missing/unexpected parameter names are reported on
    stderr so a checkpoint/model mismatch is not silently ignored.
    """
    index_path = repo / "model.safetensors.index.json"
    if index_path.exists():
        with open(index_path, encoding="utf-8") as f:
            index = json.load(f)
        # Several tensors usually share a shard; load each shard file once.
        shard_files = sorted(set(index["weight_map"].values()))
        print(f"Loading {len(shard_files)} shards...")
        state_dict = {}
        for shard in shard_files:
            print(f"  - {shard}")
            state_dict.update(load_file(str(repo / shard)))
    else:
        # Single-file fallback
        state_dict = load_file(str(repo / "model.safetensors"))

    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"WARNING: {len(result.missing_keys)} model parameter(s) not found "
              f"in the checkpoint keep their initial values, e.g. "
              f"{result.missing_keys[:5]}", file=sys.stderr)
    if result.unexpected_keys:
        print(f"WARNING: {len(result.unexpected_keys)} checkpoint tensor(s) "
              f"were ignored by the model, e.g. {result.unexpected_keys[:5]}",
              file=sys.stderr)


def load_model_and_tokenizer(repo_dir: str):
    """Load MOTHER CORE and its tokenizer from a local directory.

    *repo_dir* is expected to be a downloaded Hugging Face snapshot containing
    ``config.json``, ``tokenizer.model`` and either
    ``model.safetensors.index.json`` plus the shard files it lists, or a
    single ``model.safetensors``.

    Args:
        repo_dir: Path of the snapshot directory.

    Returns:
        ``(model, tokenizer)``: the model cast to ``DTYPE``, moved to
        ``DEVICE`` and switched to eval mode, and a loaded
        ``SentencePieceProcessor``.

    Exits the process with status 1 if the ``mother_core`` package cannot be
    imported.
    """
    repo = Path(repo_dir)

    # Load config
    with open(repo / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)
    print(f"Loaded config: {cfg['n_layers']} layers, dim={cfg['dim']}, "
          f"params~{cfg.get('_msai_total_params_b', '?')}B")

    # Load tokenizer (SentencePiece)
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(str(repo / "tokenizer.model"))
    print(f"Loaded tokenizer: vocab_size={tokenizer.vocab_size()}")

    # Build model — requires the mother_core package to be importable
    ModelConfig, MotherCoreModel = _import_model_classes()
    config = ModelConfig(
        vocab_size=cfg["vocab_size"],
        dim=cfg["dim"],
        n_layers=cfg["n_layers"],
        n_heads=cfg["n_heads"],
        n_kv_heads=cfg["n_kv_heads"],
        ff_mult=cfg["ff_mult"],
        max_seq_len=cfg["max_seq_len"],
        rope_theta=cfg["rope_theta"],
        rms_norm_eps=cfg["rms_norm_eps"],
    )
    model = MotherCoreModel(config)

    # The host-side state dict is released when the helper returns, before
    # the dtype/device conversion below.
    _load_weights(model, repo)

    model = model.to(DTYPE).to(DEVICE).eval()
    print(f"Model on {DEVICE} in {DTYPE}")
    return model, tokenizer


def _apply_repetition_penalty(logits: torch.Tensor, tokens: list[int],
                              rep_pen: float) -> None:
    """Penalise, in place, tokens that occur at least twice in *tokens*.

    A token seen ``c`` times gets a penalty factor of ``rep_pen ** (c - 1)``.
    Positive logits are divided by the factor and non-positive logits are
    multiplied by it, so the score always moves down. Dividing a negative
    logit would instead pull it toward zero and favour the repeated token.
    Ids outside the vocabulary (``logits.shape[-1]``) are skipped.
    """
    vocab_size = logits.shape[-1]
    for tok, count in Counter(tokens).items():
        if count < 2 or not 0 <= tok < vocab_size:
            continue
        factor = rep_pen ** (count - 1)
        score = logits[0, tok].item()
        logits[0, tok] = score / factor if score > 0 else score * factor


def _ban_repeated_ngrams(logits: torch.Tensor, tokens: list[int], n: int) -> None:
    """Set to -1e9, in place, every token that would repeat an n-gram of *tokens*.

    The last ``n - 1`` tokens are matched against every earlier window of the
    same length. The token that followed each match is banned. With ``n == 1``
    every previously generated token is banned. Does nothing when ``n <= 0``
    or fewer than ``n`` tokens exist.
    """
    if n <= 0 or len(tokens) < n:
        return
    # tokens[-0:] would be the whole list, hence the explicit n == 1 case.
    prefix = tuple(tokens[-(n - 1):]) if n > 1 else ()
    banned = {
        tokens[j + n - 1]
        for j in range(len(tokens) - n + 1)
        if tuple(tokens[j:j + n - 1]) == prefix
    }
    vocab_size = logits.shape[-1]
    for tok in banned:
        if 0 <= tok < vocab_size:
            logits[0, tok] = -1e9


@torch.no_grad()
def generate_greedy(model, tokenizer, question: str,
                    max_new: int = MAX_NEW,
                    rep_pen: float = REP_PEN,
                    no_repeat_ngram: int = NO_REPEAT_NGRAM) -> str:
    """
    LOCKED inference path. Greedy argmax with n-gram blocking and
    frequency-scaled repetition penalty.

    Args:
        model: Callable returning a mapping whose ``"logits"`` entry is
            indexed as ``[batch, position, vocab]``.
        tokenizer: SentencePiece processor used to encode the prompt and
            decode the answer.
        question: Text substituted into ``PROMPT_FORMAT``.
        max_new: Maximum number of tokens to generate.
        rep_pen: Base of the repetition penalty, applied to tokens seen at
            least twice, once three or more tokens have been generated.
        no_repeat_ngram: Size of n-grams that may not repeat (0 disables).

    Returns:
        The decoded generated text with surrounding whitespace stripped.
        Generation stops at EOS, after ``max_new`` tokens, or when the last
        four generated tokens are identical.
    """
    prompt = PROMPT_FORMAT.format(q=question)
    ids = [BOS_ID] + tokenizer.EncodeAsIds(prompt)
    inp = torch.tensor([ids], device=DEVICE)
    gen_out: list[int] = []

    for step in range(max_new):
        # NOTE(review): after the first step only the newest token is fed to
        # the model and no cache/past state is passed in. This is only
        # correct if MotherCoreModel keeps its own KV cache between calls,
        # which is not visible from this file. Confirm; otherwise each step
        # after the first sees a single token without context.
        x = inp if step == 0 else torch.tensor([[gen_out[-1]]], device=DEVICE)
        logits = model(x)["logits"][:, -1, :].float()

        # Never emit BOS; allow EOS only after at least one generated token.
        if not gen_out:
            logits[0, EOS_ID] = -1e9
        logits[0, BOS_ID] = -1e9

        # Frequency-scaled repetition penalty (only tokens seen ≥ 2 times)
        if len(gen_out) >= 3:
            _apply_repetition_penalty(logits, gen_out, rep_pen)

        _ban_repeated_ngrams(logits, gen_out, no_repeat_ngram)

        # Greedy argmax (no temperature, no sampling)
        nxt = logits.argmax(-1).item()
        if nxt == EOS_ID:
            break
        gen_out.append(nxt)

        # Cycle-break: 4 identical tokens in a row
        if len(gen_out) >= 4 and len(set(gen_out[-4:])) == 1:
            break

    return tokenizer.DecodeIds(gen_out).strip()


def main():
    """Command-line entry point.

    Fetches the ``MediaStreamAI/MOTHER_CORE_V2`` snapshot with
    ``huggingface_hub.snapshot_download`` and loads it. If command-line
    arguments are given, they are joined into one question and answered once.
    Otherwise an interactive loop runs until 'quit'/'exit', EOF, or Ctrl-C
    at the prompt. Exits with status 1 if huggingface_hub is not installed.
    """
    # Download from HF if needed
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("ERROR: pip install huggingface_hub", file=sys.stderr)
        sys.exit(1)

    print("Downloading MediaStreamAI/MOTHER_CORE_V2 ...")
    repo_dir = snapshot_download(repo_id="MediaStreamAI/MOTHER_CORE_V2")
    print(f"Local snapshot: {repo_dir}")
    model, tokenizer = load_model_and_tokenizer(repo_dir)

    if len(sys.argv) > 1:
        question = " ".join(sys.argv[1:])
        print(f"\nQ: {question}")
        ans = generate_greedy(model, tokenizer, question)
        print(f"A: {ans}")
        return

    print("\nInteractive mode. Type 'quit' to exit.\n")
    while True:
        try:
            q = input("Q: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if q.lower() in ("quit", "exit"):
            break
        if not q:
            continue
        try:
            ans = generate_greedy(model, tokenizer, q)
        except KeyboardInterrupt:
            # Ctrl-C while generating abandons only this answer; Ctrl-C at
            # the prompt still ends the session.
            print("\n[generation interrupted]\n")
            continue
        print(f"A: {ans}\n")


if __name__ == "__main__":
    main()