File size: 3,633 Bytes
616bac5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | """
KnowForge-0.6B inference — standalone; requires only torch and transformers (plus accelerate, which device_map="auto" needs).
CLI: python inference.py "ZELPH RULES: ... Question: ..."
API: from inference import ask; result = ask("ZELPH RULES: ...")
"""
import re
import sys
from pathlib import Path
from typing import Optional
# Directory containing this script; the tokenizer and model weights are
# loaded from here by _load_pipeline().
_MODEL_DIR = Path(__file__).parent

# Lazily-built text-generation pipeline (lazy singleton): created on the
# first _load_pipeline() call and reused by every later call.
_pipeline = None


def _load_pipeline():
    """Return the shared text-generation pipeline, building it on first call.

    The tokenizer and model are loaded from ``_MODEL_DIR``. Half precision
    (float16) is used only when a CUDA device is available. On CPU-only hosts
    the model is loaded in float32, because ``device_map="auto"`` would
    otherwise leave an fp16 model on the CPU, where it is slow and where some
    ops are unimplemented on older torch releases.

    NOTE(review): ``trust_remote_code=True`` executes any Python code shipped
    in the model directory; only point ``_MODEL_DIR`` at trusted files.
    NOTE(review): there is no lock here, so concurrent first calls from
    several threads may each build a pipeline.

    Returns:
        The cached ``transformers`` text-generation pipeline.
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    # Heavy imports are deferred so that importing this module stays cheap.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    # fp16 on CPU is slow at best and raises "... not implemented for 'Half'"
    # on older torch builds, so only use it when a CUDA GPU is present.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(str(_MODEL_DIR), trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        str(_MODEL_DIR),
        dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()  # inference only
    _pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )
    return _pipeline
def _split_think(raw: str) -> tuple[str, str]:
    """Split ``<think>...</think>`` reasoning from the final answer.

    Handles these output shapes:

    * One or more closed ``<think>...</think>`` blocks. Their contents become
      the reasoning (joined by blank lines), and the blocks are removed from
      the answer.
    * An unclosed trailing ``<think>`` (generation truncated mid-reasoning).
      Everything after it is reasoning; any text before it stays in the answer.
    * A ``</think>`` with no preceding ``<think>`` (e.g. when the chat
      template pre-fills the opening tag, so it never appears in the
      generated text). Everything before the first ``</think>`` is reasoning.

    Args:
        raw: Decoded model output.

    Returns:
        ``(answer, reasoning)``, both stripped. ``reasoning`` is ``""`` when
        the output contains no think tags.
    """
    open_tag, close_tag = "<think>", "</think>"
    text = raw

    first_open = text.find(open_tag)
    first_close = text.find(close_tag)
    if first_close != -1 and (first_open == -1 or first_close < first_open):
        # Closing tag without an opener: restore the opener so it parses.
        text = open_tag + text

    block_re = re.compile(r"<think>(.*?)</think>", re.DOTALL)
    parts = [body.strip() for body in block_re.findall(text)]
    answer = block_re.sub("", text)

    # A leftover unclosed <think> means generation was cut off while reasoning.
    head, sep, tail = answer.partition(open_tag)
    if sep:
        parts.append(tail.strip())
        answer = head

    reasoning = "\n\n".join(part for part in parts if part)
    return answer.strip(), reasoning
def ask(
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    do_sample: bool = False,
) -> dict:
    """
    Run a KnowForge prompt through the model.

    Args:
        prompt: Full user-turn text, e.g. "ZELPH RULES: ... Question: ..."
        max_new_tokens: Max tokens to generate; must be >= 1.
        temperature: Sampling temperature (ignored when do_sample=False).
            Must be > 0 when do_sample=True.
        do_sample: Enable sampling; set True + temperature>0 for stochastic output.

    Returns:
        {
            "answer": str — text after </think> (or full output if no think block),
            "reasoning": str — text inside <think>...</think>, empty string if absent,
        }

    Raises:
        ValueError: if max_new_tokens < 1, or if do_sample=True with a
            temperature that is not > 0. These are checked before the model
            is loaded, so bad arguments fail fast instead of after the load.
    """
    if max_new_tokens < 1:
        raise ValueError(f"max_new_tokens must be >= 1, got {max_new_tokens}")
    # `not (x > 0)` also rejects NaN. The default temperature of 0.0 is only
    # valid for greedy decoding.
    if do_sample and not (temperature > 0):
        raise ValueError(
            f"temperature must be > 0 when do_sample=True, got {temperature}; "
            "use do_sample=False for greedy decoding"
        )

    pipe = _load_pipeline()
    messages = [
        {
            "role": "system",
            "content": (
                "You are given rules for a fictional system that does NOT exist in the "
                "real world. Reason STRICTLY from the rules provided. Do NOT use any "
                "outside knowledge. Show your reasoning inside <think>...</think> tags "
                "before giving your final answer."
            ),
        },
        {"role": "user", "content": prompt},
    ]
    gen_kwargs: dict = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        # Explicit pad id, reusing the tokenizer's EOS token.
        "pad_token_id": pipe.tokenizer.eos_token_id,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature
    outputs = pipe(messages, **gen_kwargs)
    # Chat-style input: generated_text is the message list, and its last
    # entry is the newly generated assistant turn.
    raw = outputs[0]["generated_text"][-1]["content"]
    answer, reasoning = _split_think(raw)
    return {"answer": answer, "reasoning": reasoning}
def _main() -> None:
    """Command-line entry point: answer the prompt given on the command line.

    All arguments after the script name are joined with single spaces into
    one prompt. Prints the answer, followed by the reasoning when there is
    any. Exits with status 1 and a usage message on stderr when no non-blank
    prompt is supplied.
    """
    prompt = " ".join(sys.argv[1:])
    if not prompt.strip():
        # Usage errors go to stderr so stdout carries only model output.
        print('Usage: python inference.py "ZELPH RULES: ... Question: ..."', file=sys.stderr)
        sys.exit(1)
    result = ask(prompt)
    print("Answer:", result["answer"])
    if result["reasoning"]:
        print("\nReasoning:")
        print(result["reasoning"])


if __name__ == "__main__":
    _main()
|