File size: 3,633 Bytes
616bac5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
KnowForge-0.6B inference — standalone; needs only `transformers` and `torch`, no other dependencies.

CLI:   python inference.py "ZELPH RULES: ... Question: ..."
API:   from inference import ask; result = ask("ZELPH RULES: ...")
"""
import re
import sys
from pathlib import Path
from typing import Optional

# Directory containing this script; _load_pipeline() loads the tokenizer and
# model weights from here.
_MODEL_DIR = Path(__file__).parent
_pipeline  = None  # lazy singleton: built by the first _load_pipeline() call, then reused


def _load_pipeline():
    """Return the shared text-generation pipeline, building it on first use.

    The tokenizer and model are loaded from ``_MODEL_DIR``. The resulting
    pipeline is cached in the module-level ``_pipeline``, so later calls
    return the same object without reloading. ``torch`` and ``transformers``
    are imported here rather than at module level, so importing this module
    does not pull them in.

    Half precision is used only when a CUDA device is available. Otherwise the
    model is loaded in float32: float16 inference on CPU is slow, and some
    PyTorch builds lack Half kernels for CPU ops.

    NOTE: ``trust_remote_code=True`` lets transformers execute custom model
    code shipped alongside the weights; only use this with a trusted model dir.
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    model_path = str(_MODEL_DIR)
    # fp16 halves memory on GPU; on CPU fall back to fp32 for speed/support.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()  # inference only (e.g. disables dropout)
    _pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )
    return _pipeline


def _split_think(raw: str) -> tuple[str, str]:
    """Split ``<think>...</think>`` reasoning from the final answer.

    Returns ``(answer, reasoning)``:

    * Closed block(s): ``reasoning`` is the content of the first
      ``<think>...</think>``; ``answer`` is ``raw`` with every closed block
      removed (text before and after the block is kept).
    * Unclosed ``<think>`` (generation was cut off while reasoning):
      ``answer`` is empty and ``reasoning`` is everything after the tag.
    * Orphan ``</think>`` with no opening tag (e.g. if the prompt/chat
      template supplied the opening tag, so it is not in the completion):
      text before the tag is the reasoning, text after it is the answer.
    * No tags at all: the whole (stripped) output is the answer.
    """
    closed_match = re.search(r"<think>(.*?)</think>", raw, re.DOTALL)
    if closed_match:
        reasoning = closed_match.group(1).strip()
        answer = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
        return answer, reasoning

    # Unclosed <think> block — generation was cut off during reasoning
    open_match = re.search(r"<think>(.*)", raw, re.DOTALL)
    if open_match:
        return "", open_match.group(1).strip()

    # No <think> anywhere (the check above would have matched), but a closing
    # tag is present: everything before it is reasoning, everything after it
    # is the answer. Without this, the tag and reasoning leak into the answer.
    reasoning, closing_tag, answer = raw.partition("</think>")
    if closing_tag:
        return answer.strip(), reasoning.strip()

    return raw.strip(), ""


def ask(
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    do_sample: bool = False,
) -> dict:
    """
    Run a KnowForge prompt through the model.

    Args:
        prompt: Full user-turn text, e.g. "ZELPH RULES: ... Question: ..."
        max_new_tokens: Max tokens to generate.
        temperature: Sampling temperature. Ignored when do_sample=False; when
            do_sample=True it must be > 0 (ints such as 2 are accepted).
        do_sample: Enable sampling; set True + temperature>0 for stochastic
            output. The default (False) is greedy decoding.

    Returns:
        {
            "answer":    str  — model output with the <think>...</think> block
                                removed (full output if no think block),
            "reasoning": str  — text inside <think>...</think>, empty string if absent,
        }

    Raises:
        ValueError: if do_sample=True and temperature is not > 0. This is
            checked before the model is loaded.
    """
    if do_sample:
        # transformers' temperature warper accepts only a strictly positive
        # *float* (an int like 2 is rejected), and it would only fail after
        # the model is loaded. Normalize and validate up front instead.
        temperature = float(temperature)
        if not temperature > 0:  # written this way so NaN is rejected too
            raise ValueError(
                f"temperature must be > 0 when do_sample=True (got {temperature!r}); "
                "use do_sample=False for greedy decoding."
            )

    pipe = _load_pipeline()
    messages = [
        {
            "role": "system",
            "content": (
                "You are given rules for a fictional system that does NOT exist in the "
                "real world. Reason STRICTLY from the rules provided. Do NOT use any "
                "outside knowledge. Show your reasoning inside <think>...</think> tags "
                "before giving your final answer."
            ),
        },
        {"role": "user", "content": prompt},
    ]

    gen_kwargs: dict = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": pipe.tokenizer.eos_token_id,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature

    outputs = pipe(messages, **gen_kwargs)

    # For chat-format input, generated_text is the message list with the new
    # assistant turn appended; its content is the raw completion.
    raw = outputs[0]["generated_text"][-1]["content"]
    answer, reasoning = _split_think(raw)
    return {"answer": answer, "reasoning": reasoning}


def _main():
    """CLI entry point: join all command-line words into one prompt and print the result.

    Prints a usage line and exits with status 1 when no prompt words are given.
    The reasoning section is printed only when it is non-empty.
    """
    words = sys.argv[1:]
    if not words:
        print("Usage: python inference.py \"ZELPH RULES: ... Question: ...\"")
        sys.exit(1)

    result = ask(" ".join(words))
    print("Answer:", result["answer"])

    reasoning = result["reasoning"]
    if reasoning:
        # Blank line, header, then the reasoning text, each on its own line.
        print("\nReasoning:", reasoning, sep="\n")


# Run the CLI only when executed as a script, not when imported for ask().
if __name__ == "__main__":
    _main()