"""
KnowForge Encoder — standalone inference.
Predicts transform_type and answer_type from a KnowForge input prompt.
CLI: python inference.py "A cao hơn B, B cao hơn C. A có cao hơn C không?"
API: from inference import predict; result = predict("A cao hơn B...")
"""
import json
import re
import sys
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# Directory containing this module; model artifacts are resolved relative to it.
_HERE = Path(__file__).parent

# ── Label maps (must match training) ─────────────────────────────────────────
# Position i in each list names the class for output unit i of the matching head.
TRANSFORM_LABELS = [
    "linear_to_cyclic",
    "relation_property_check",
    "relation_to_graph",
]
ATYPE_LABELS = [
    "conditional_answer",
    "exact_answer",
    "need_more_rule",
    "unresolvable_without_observation",
]
# ── Tokenizer ────────────────────────────────────────────────────────────────
# A token is either a run of word characters or one character that is neither
# a word character nor whitespace (punctuation / symbols).
_TOK_RE = re.compile(r"[\w]+|[^\w\s]", re.UNICODE)


def _tokenize(text: str) -> list:
    """Lower-case *text* and return its word and punctuation tokens in order."""
    matches = _TOK_RE.finditer(text.lower())
    return [match.group(0) for match in matches]
# ── Model architecture ───────────────────────────────────────────────────────
class _MultiTaskEncoder(nn.Module):
    """Convolutional multi-task text encoder.

    Token ids are embedded (id 0 is the padding index), passed through
    ``n_layers`` Conv1d(kernel 3, padding 1) + ReLU stages, max-pooled over
    the sequence, and fed to linear classification heads. ``forward`` only
    returns the transform and answer-type logits; the remaining heads are
    declared so a training checkpoint's state_dict loads without key errors.
    """

    def __init__(self, vocab_size: int, embed_dim: int = 64,
                 hidden_dim: int = 64, n_layers: int = 2, dropout: float = 0.3):
        super().__init__()
        width = 2 * hidden_dim  # conv channel count (128 with the defaults)
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)
        # First conv reads the embedding; later ones read the previous conv's output.
        in_channels = [embed_dim if i == 0 else width for i in range(n_layers)]
        self.encoder = nn.Sequential(*[
            module
            for c_in in in_channels
            for module in (nn.Conv1d(c_in, width, 3, padding=1), nn.ReLU())
        ])
        self.transform_head = nn.Linear(width, len(TRANSFORM_LABELS))
        self.atype_head = nn.Linear(width, len(ATYPE_LABELS))
        # Not used by forward(); present only so checkpoint keys match exactly.
        self.etype_head = nn.Linear(width, 24)
        self.uncertainty_head = nn.Linear(width, 5)
        self.bio_head = nn.Linear(width, 12)

    def forward(self, token_ids: torch.Tensor) -> dict:
        """Return ``{"transform": logits, "atype": logits}`` for a (B, L) id tensor."""
        embedded = self.dropout(self.embedding(token_ids))   # (B, L, E)
        features = self.encoder(embedded.permute(0, 2, 1))    # (B, C, L)
        pooled = features.amax(dim=2)                         # global max over L -> (B, C)
        return {
            "transform": self.transform_head(pooled),
            "atype": self.atype_head(pooled),
        }
# ── Lazy singleton loader ────────────────────────────────────────────────────
# Populated together by _load() on its first successful call; None until then.
_encoder: Optional["_MultiTaskEncoder"] = None
_vocab: Optional[dict] = None


def _read_json(path: Path) -> dict:
    """Parse the JSON file at *path* as UTF-8, closing the handle afterwards."""
    # Explicit UTF-8: the default text encoding is platform/locale dependent
    # (e.g. cp1252 on Windows) and may fail on non-ASCII vocab entries such as
    # the Vietnamese words this model is used with.
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _load():
    """Build the encoder and vocabulary on first use and cache them.

    Artifacts are looked up in ``_HERE``:
      * ``vocab.json`` (required) — token -> id mapping;
      * ``model_config.json`` (optional) — overrides for the constructor
        hyper-parameters; missing keys fall back to the defaults below;
      * ``best_model.safetensors``, or failing that ``best_model.pt`` — weights.

    Returns:
        ``(model, vocab)``: the model in eval mode and the vocabulary dict.
        Later calls return the cached pair.

    Raises:
        FileNotFoundError: if ``vocab.json`` is missing or neither weight file
            exists.
    """
    global _encoder, _vocab
    if _encoder is not None:
        return _encoder, _vocab

    vocab_path = _HERE / "vocab.json"
    cfg_path = _HERE / "model_config.json"
    sf_path = _HERE / "best_model.safetensors"
    pt_path = _HERE / "best_model.pt"

    if not vocab_path.exists():
        raise FileNotFoundError(f"vocab.json not found at {vocab_path}")
    vocab = _read_json(vocab_path)
    cfg = _read_json(cfg_path) if cfg_path.exists() else {}

    if sf_path.exists():
        # Imported lazily so safetensors is only required when that file exists.
        from safetensors.torch import load_file
        state = load_file(str(sf_path))
    elif pt_path.exists():
        state = torch.load(str(pt_path), map_location="cpu", weights_only=True)
    else:
        raise FileNotFoundError(f"No model weights found at {sf_path} or {pt_path}")

    model = _MultiTaskEncoder(
        vocab_size=cfg.get("vocab_size", len(vocab)),
        embed_dim=cfg.get("embed_dim", 64),
        hidden_dim=cfg.get("hidden_dim", 64),
        n_layers=cfg.get("n_layers", 2),
        dropout=cfg.get("dropout", 0.3),
    )
    model.load_state_dict(state)
    model.eval()  # disables dropout for inference

    # Publish the cache only once everything has loaded, so a failure never
    # leaves a half-initialised state (vocab cached but no model) behind.
    _encoder, _vocab = model, vocab
    return _encoder, _vocab
# ── Public API ───────────────────────────────────────────────────────────────
def predict(text: str) -> dict:
    """
    Classify a KnowForge input by transform type and answer type.

    Args:
        text: Natural-language input (rules + question, or the question alone).

    Returns:
        {
            "transform_type": str — one of linear_to_cyclic /
                                     relation_property_check /
                                     relation_to_graph,
            "transform_confidence": float — softmax probability of that label,
                                            rounded to 4 decimals,
            "answer_type": str — one of conditional_answer /
                                  exact_answer /
                                  need_more_rule /
                                  unresolvable_without_observation,
            "atype_confidence": float — softmax probability of that label,
                                        rounded to 4 decimals,
        }
    """
    model, vocab = _load()

    # Out-of-vocabulary tokens map to the "<UNK>" id (1 if the vocab has none).
    unk_id = vocab.get("<UNK>", 1)
    token_ids = [vocab.get(tok, unk_id) for tok in _tokenize(text)]
    if not token_ids:
        # Token-less input becomes a single id 0 (the embedding's padding
        # index) so the sequence dimension is never empty.
        token_ids = [0]
    batch = torch.tensor([token_ids], dtype=torch.long)  # (1, L)

    with torch.no_grad():
        logits = model(batch)

    def best_label(head: str, labels: list) -> tuple:
        # Highest-probability label for one head and its rounded probability.
        probs = F.softmax(logits[head][0], dim=-1)
        idx = int(probs.argmax())
        return labels[idx], round(float(probs[idx]), 4)

    transform_type, transform_conf = best_label("transform", TRANSFORM_LABELS)
    answer_type, atype_conf = best_label("atype", ATYPE_LABELS)
    return {
        "transform_type": transform_type,
        "transform_confidence": transform_conf,
        "answer_type": answer_type,
        "atype_confidence": atype_conf,
    }
def _main(argv: Optional[list] = None) -> None:
    """Command-line entry point: classify the text given as arguments.

    Args:
        argv: Arguments excluding the program name; defaults to
            ``sys.argv[1:]``. All arguments are joined with single spaces,
            so the input does not have to be quoted as one shell word.

    Prints the predicted labels with their confidences. Exits with status 1,
    after printing usage to stderr, when no input text is given.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        print("Usage: python inference.py \"<input text>\"", file=sys.stderr)
        sys.exit(1)
    result = predict(" ".join(args))
    print(f"Transform: {result['transform_type']} ({result['transform_confidence']:.2%})")
    print(f"Answer type: {result['answer_type']} ({result['atype_confidence']:.2%})")


if __name__ == "__main__":
    _main()
|