|
|
|
|
|
""" |
|
|
K-Simplex Language Model - Inference Script |
|
|
|
|
|
Loads a trained k-simplex LLM checkpoint and generates text using |
|
|
geometrically-validated autoregressive sampling. |
|
|
|
|
|
Usage: |
|
|
python inference.py --checkpoint checkpoint_epoch_008.pt --prompt "ROMEO: " |
|
|
python inference.py --repo AbstractPhil/ksimplex-llm-prototype --prompt "To be or not" |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import tiktoken |
|
|
from pathlib import Path |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def factorial(n: int) -> int:
    """Return n! (thin local alias for math.factorial; raises ValueError for n < 0)."""
    return math.factorial(n)
|
|
|
|
|
|
|
|
def cayley_menger_volume_squared(vertices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Squared k-simplex volume from vertex coordinates via the Cayley-Menger
    determinant.

    Args:
        vertices: [*, nv, edim] vertex coordinates (nv = k + 1)

    Returns:
        d2: [*, n_pairs] squared pairwise distances (strict-upper-triangle order)
        vol2: [*] squared volume; may come out negative numerically for
            degenerate simplices, which callers use as a validity signal
    """
    nv = vertices.shape[-2]
    k = nv - 1
    lead = vertices.shape[:-2]

    # Full [*, nv, nv] matrix of squared pairwise distances.
    delta = vertices.unsqueeze(-2) - vertices.unsqueeze(-3)
    sq_dist = delta.pow(2).sum(dim=-1)

    # Flatten the strict upper triangle into the pair vector.
    rows, cols = torch.triu_indices(nv, nv, offset=1)
    d2 = sq_dist[..., rows, cols]

    # Cayley-Menger matrix: ones border, zero corner, squared distances inside
    # (the distance matrix already has a zero diagonal).
    cm = torch.ones(*lead, nv + 1, nv + 1, device=vertices.device, dtype=vertices.dtype)
    cm[..., 0, 0] = 0.0
    cm[..., 1:, 1:] = sq_dist

    det = torch.linalg.det(cm)

    # vol^2 = (-1)^(k+1) * det(CM) / (2^k * (k!)^2)
    vol2 = ((-1) ** (k + 1)) * det / ((2 ** k) * (math.factorial(k) ** 2))

    return d2, vol2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SimplexTemplate(nn.Module):
    """Fixed (non-learned) template vertices for a k-simplex.

    Vertices are placed on a circle in the first two embedding dimensions,
    with small deterministic harmonic offsets in any remaining dimensions so
    the template is not confined to a 2-D plane.
    """

    def __init__(self, k: int, edim: int, scale: float = 1.0):
        super().__init__()
        self.k = k
        self.nv = k + 1  # a k-simplex has k + 1 vertices
        self.edim = edim

        verts = torch.zeros(self.nv, edim)
        for i in range(self.nv):
            theta = 2 * math.pi * i / self.nv
            verts[i, 0] = scale * math.cos(theta)
            if edim > 1:
                verts[i, 1] = scale * math.sin(theta)
            if edim > 2:
                verts[i, 2] = scale * 0.3 * math.cos(theta * 2)
                # Progressively smaller harmonics fill the higher dims.
                for d in range(3, edim):
                    verts[i, d] = scale * 0.1 * math.sin(theta * (d + 1))

        # Buffer (not Parameter): follows .to(device) but is never trained.
        self.register_buffer('template', verts)

    def forward(self) -> torch.Tensor:
        """Return the [nv, edim] template vertex tensor."""
        return self.template
|
|
|
|
|
|
|
|
class KSimplexChannel(nn.Module):
    """One k-simplex channel: deforms a template simplex per input and gates
    the vertex features by the resulting geometry (pairwise distances and
    squared volume)."""

    def __init__(self, k: int, edim: int, hidden: int, feat_dim: int, base_deform: float = 0.05):
        super().__init__()
        self.k = k
        self.nv = k + 1
        self.edim = edim
        self.feat_dim = feat_dim
        self.base_deform = base_deform

        # Fixed reference simplex; learned deformations are added on top.
        self.template = SimplexTemplate(k, edim)

        # Input-conditioned per-vertex coordinate offsets and feature vectors.
        self._to_coords = nn.Linear(hidden, self.nv * edim)
        self._to_feats = nn.Linear(hidden, self.nv * feat_dim)

        # Geometry descriptor: all pairwise squared distances plus vol^2.
        n_pairs = (self.nv * (self.nv - 1)) // 2
        self.geo_dim = n_pairs + 1

        self._geo_gate = nn.Sequential(
            nn.Linear(self.geo_dim, feat_dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [*, hidden]

        Returns:
            out: [*, feat_dim + geo_dim] geometry-gated features concatenated
                with the raw geometry descriptor
            vol2: [*] squared simplex volume (used as a validity-loss signal)
            mean_d2: [*] mean squared pairwise distance
        """
        # Deform the template by a small, input-dependent per-vertex offset.
        offsets = self._to_coords(x).unflatten(-1, (self.nv, self.edim))
        verts = self.template() + self.base_deform * offsets

        feats = self._to_feats(x).unflatten(-1, (self.nv, self.feat_dim))

        d2, vol2 = cayley_menger_volume_squared(verts)
        geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)

        # Soft gate from the geometry descriptor; near-binary validity from
        # the sign of vol^2 (steep sigmoid approximates a step at zero).
        gate = self._geo_gate(geo)
        validity = torch.sigmoid(vol2 * 1e6).unsqueeze(-1)

        pooled = feats.mean(dim=-2) * gate * validity
        out = torch.cat([pooled, geo], dim=-1)

        return out, vol2, d2.mean(dim=-1)
|
|
|
|
|
|
|
|
class TokenToKChannels(nn.Module):
    """Project token embeddings into a stack of k-simplex channels
    (k = 1 .. depth), bringing every channel to a common width."""

    def __init__(self, embed_dim: int, hidden: int, depth: int, edim: int, feat_dim: int):
        super().__init__()
        self.depth = depth

        self._proj = nn.Linear(embed_dim, hidden)
        self._channels = nn.ModuleList([
            KSimplexChannel(k=level + 1, edim=edim, hidden=hidden, feat_dim=feat_dim)
            for level in range(depth)
        ])

        # Channels differ only in geo_dim; project narrower ones up to the
        # widest so they can be stacked along a new axis.
        self.out_dims = [ch.feat_dim + ch.geo_dim for ch in self._channels]
        self.max_dim = max(self.out_dims)

        self._pads = nn.ModuleList([
            nn.Identity() if d == self.max_dim else nn.Linear(d, self.max_dim)
            for d in self.out_dims
        ])

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
        """
        Args:
            x: [B, T, embed_dim]

        Returns:
            out: [B, T, K, max_dim] stacked channel outputs
            vol2_list: per-k [B, T] squared volumes
            d2_list: per-k [B, T] mean squared distances
        """
        h = self._proj(x)

        outputs, vol2_list, d2_list = [], [], []
        for channel, pad in zip(self._channels, self._pads):
            feat, vol2, d2 = channel(h)
            outputs.append(pad(feat))
            vol2_list.append(vol2)
            d2_list.append(d2)

        # New K axis just before the feature axis.
        return torch.stack(outputs, dim=-2), vol2_list, d2_list
|
|
|
|
|
|
|
|
class KChannelCrossAttention(nn.Module):
    """Self-attention across the K simplex levels at each sequence position."""

    def __init__(self, dim: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D]
        Returns:
            [B, T, K, D]
        """
        B, T, K, D = x.shape

        # Treat every (batch, position) pair as its own length-K sequence.
        seq = x.view(B * T, K, D)
        mixed, _ = self.attn(seq, seq, seq)

        # Residual connection with post-norm.
        return self.norm(seq + mixed).view(B, T, K, D)
|
|
|
|
|
|
|
|
class CausalSequenceAttention(nn.Module):
    """Causal self-attention over sequence positions, applied to the flattened
    [K*D] representation of each position.

    NOTE(review): `dim` must equal K * D of the tensors passed to forward(),
    since the input is flattened before attending.
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)

        # True on/below the diagonal = this position may be attended to.
        visible = torch.ones(max_seq_len, max_seq_len).tril().bool()
        self.register_buffer('_causal_mask', visible)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D]
        Returns:
            [B, T, K, D]
        """
        B, T, K, D = x.shape
        flat = x.view(B, T, K * D)

        # Additive float mask: 0 where attention is allowed, -inf where not.
        blocked = ~self._causal_mask[:T, :T]
        bias = blocked.float().masked_fill(blocked, float('-inf'))

        mixed, _ = self.attn(flat, flat, flat, attn_mask=bias)

        # Residual connection with post-norm.
        out = self.norm(flat + mixed)
        return out.view(B, T, K, D)
|
|
|
|
|
|
|
|
class GeoBlock(nn.Module):
    """Geometric block: k-channel cross-attention, causal sequence attention
    over the flattened channel stack, then a position-wise MLP.

    Args:
        dim: per-channel feature width D
        num_heads: attention heads for the sequence attention
            (must divide dim * depth)
        max_seq_len: maximum sequence length for the causal mask
        depth: number of k-channels K stacked at each position
        dropout: dropout probability
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, depth: int, dropout: float = 0.1):
        super().__init__()
        # Mixes the K channels at each position. Head count is fixed at 4
        # here, independent of `num_heads` — presumably intentional; verify.
        self.k_attn = KChannelCrossAttention(dim, num_heads=4, dropout=dropout)
        # BUG FIX: CausalSequenceAttention flattens its input to [B, T, K*D]
        # before attending, so its embedding dim must be dim * depth (as the
        # MLP below already assumes). The original passed `dim`, which raised
        # a shape error inside MultiheadAttention for depth > 1.
        self.seq_attn = CausalSequenceAttention(dim * depth, num_heads, max_seq_len, dropout)

        # Position-wise feed-forward over the flattened [K*D] vector.
        self.mlp = nn.Sequential(
            nn.Linear(dim * depth, dim * depth * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * depth * 4, dim * depth),
            nn.Dropout(dropout),
        )
        self.mlp_norm = nn.LayerNorm(dim * depth)
        self.depth = depth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D]
        Returns:
            [B, T, K, D]
        """
        # Mix across k-levels, then across sequence positions.
        x = self.k_attn(x)
        x = self.seq_attn(x)

        # Residual MLP with post-norm over the flattened channels.
        B, T, K, D = x.shape
        flat = x.view(B, T, K * D)
        flat = self.mlp_norm(flat + self.mlp(flat))

        return flat.view(B, T, K, D)
|
|
|
|
|
|
|
|
class KSimplexLM(nn.Module):
    """K-Simplex Language Model.

    Pipeline: token + position embeddings -> k-simplex channels -> stack of
    GeoBlocks -> LayerNorm -> linear LM head over the flattened channels.
    """

    def __init__(
        self,
        vocab_size: int = 50257,
        max_seq_len: int = 256,
        embed_dim: int = 384,
        depth: int = 4,
        edim: int = 16,
        feat_dim: int = 96,
        hidden: int = 384,
        num_heads: int = 8,
        num_blocks: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.depth = depth

        # Token and learned absolute position embeddings.
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.embed_drop = nn.Dropout(dropout)

        # Embeddings -> K parallel simplex channels of equal width.
        self.to_k_channels = TokenToKChannels(embed_dim, hidden, depth, edim, feat_dim)

        k_dim = self.to_k_channels.max_dim
        self.blocks = nn.ModuleList([
            GeoBlock(k_dim, num_heads, max_seq_len, depth, dropout)
            for _ in range(num_blocks)
        ])

        # Final norm + untied linear head over the flattened channel stack.
        self.ln_f = nn.LayerNorm(k_dim * depth)
        self.lm_head = nn.Linear(k_dim * depth, vocab_size, bias=False)

        self._init_weights()

    def _init_weights(self):
        """GPT-style init: N(0, 0.02) for weights, zeros for linear biases."""
        for module in self.modules():
            if isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, std=0.02)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """
        Args:
            x: [B, T] token indices

        Returns:
            logits: [B, T, vocab_size]
            geo_info: {'vol2': [...], 'd2': [...]} per-k geometry tensors
        """
        B, T = x.shape

        positions = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.embed_drop(self.embed(x) + self.pos_embed(positions))

        # [B, T, embed_dim] -> [B, T, K, max_dim] plus per-k geometry signals.
        h, vol2_list, d2_list = self.to_k_channels(h)

        for block in self.blocks:
            h = block(h)

        logits = self.lm_head(self.ln_f(h.view(B, T, -1)))

        return logits, {'vol2': vol2_list, 'd2': d2_list}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(
    checkpoint_path: str = None,
    repo_id: str = None,
    device: str = None,
) -> tuple[KSimplexLM, tiktoken.Encoding]:
    """
    Load a KSimplexLM plus GPT-2 tokenizer from a local checkpoint or the Hub.

    Args:
        checkpoint_path: Local path to a .pt checkpoint
        repo_id: HuggingFace repo ID (e.g., "AbstractPhil/ksimplex-llm-prototype")
        device: Device to load to; auto-detects CUDA when None

    Returns:
        model: KSimplexLM in eval mode on `device`
        tokenizer: tiktoken GPT-2 encoding

    Raises:
        ValueError: if neither checkpoint_path nor repo_id is provided
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if repo_id:
        # Hub checkpoints ship the model config as a separate JSON file.
        checkpoint_path = hf_hub_download(repo_id, "checkpoint_latest.pt")
        config_path = hf_hub_download(repo_id, "config.json")
        config = json.loads(Path(config_path).read_text())
        checkpoint = torch.load(checkpoint_path, map_location=device)
    elif checkpoint_path:
        # Local checkpoints embed the config under config['model'].
        checkpoint = torch.load(checkpoint_path, map_location=device)
        config = checkpoint.get('config', {}).get('model', {})
    else:
        raise ValueError("Must provide checkpoint_path or repo_id")

    # Dropout is forced off for inference regardless of the training config.
    model = KSimplexLM(
        vocab_size=config.get('vocab_size', 50257),
        max_seq_len=config.get('max_seq_len', 256),
        embed_dim=config.get('embed_dim', 384),
        depth=config.get('depth', 4),
        edim=config.get('edim', 16),
        feat_dim=config.get('feat_dim', 96),
        hidden=config.get('hidden', 384),
        num_heads=config.get('num_heads', 8),
        num_blocks=config.get('num_blocks', 8),
        dropout=0.0,
    )

    # Accept both wrapped ('model_state_dict') and bare state dicts.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return model, tiktoken.get_encoding("gpt2")
|
|
|
|
|
|
|
|
@torch.no_grad()
def generate(
    model: KSimplexLM,
    tokenizer: tiktoken.Encoding,
    prompt: str,
    max_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    device: str = None,
) -> str:
    """
    Autoregressively generate text from a prompt.

    Args:
        model: KSimplexLM model (should be in eval mode)
        tokenizer: tiktoken GPT-2 encoding
        prompt: Input text prompt
        max_tokens: Upper bound on newly sampled tokens
        temperature: Softmax temperature (must be > 0)
        top_k: Keep only the k most likely tokens (0 disables)
        top_p: Nucleus sampling mass threshold (1.0 disables)
        device: Device; defaults to the model's device

    Returns:
        Generated text including the prompt.
    """
    if device is None:
        device = next(model.parameters()).device

    ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)

    for _ in range(max_tokens):
        # Keep only the most recent max_seq_len tokens as context.
        if ids.shape[1] > model.max_seq_len:
            ids = ids[:, -model.max_seq_len:]

        logits, _ = model(ids)
        next_logits = logits[:, -1, :] / temperature

        # Top-k: discard everything below the k-th largest logit.
        if top_k > 0:
            kth = torch.topk(next_logits, min(top_k, next_logits.size(-1))).values[:, [-1]]
            next_logits = next_logits.masked_fill(next_logits < kth, float('-inf'))

        # Nucleus (top-p): discard the tail of the sorted distribution whose
        # cumulative probability exceeds top_p, always keeping the argmax.
        if top_p < 1.0:
            sorted_logits, order = torch.sort(next_logits, descending=True)
            cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            drop_sorted = cum_probs > top_p
            drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
            drop_sorted[..., 0] = False
            # `order` is a full permutation, so scatter fills every column.
            drop = drop_sorted.scatter(1, order, drop_sorted)
            next_logits = next_logits.masked_fill(drop, float('-inf'))

        probs = F.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id], dim=1)

        # GPT-2 end-of-text terminates generation early.
        if next_id.item() == tokenizer.eot_token:
            break

    return tokenizer.decode(ids[0].tolist())
|
|
|
|
|
|
|
|
@torch.no_grad()
def analyze_geometry(
    model: KSimplexLM,
    tokenizer: tiktoken.Encoding,
    text: str,
    device: str = None,
) -> dict:
    """
    Compute per-k-level geometric statistics for an encoded text.

    Args:
        model: KSimplexLM model
        tokenizer: tiktoken GPT-2 encoding
        text: Input text
        device: Device; defaults to the model's device

    Returns:
        {'k1': {...}, 'k2': {...}, ...} with vol2/d2 summary statistics and
        the fraction of positions whose squared volume is positive.
    """
    if device is None:
        device = next(model.parameters()).device

    ids = torch.tensor([tokenizer.encode(text)], dtype=torch.long, device=device)
    _, geo_info = model(ids)

    stats = {}
    for level, (vol2, d2) in enumerate(zip(geo_info['vol2'], geo_info['d2']), 1):
        v = vol2.cpu().numpy()
        d = d2.cpu().numpy()

        stats[f'k{level}'] = {
            'vol2_mean': float(v.mean()),
            'vol2_std': float(v.std()),
            'vol2_min': float(v.min()),
            'vol2_max': float(v.max()),
            # Positive squared volume = geometrically valid simplex.
            'validity_rate': float((v > 0).mean()),
            'd2_mean': float(d.mean()),
        }

    return stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load a model, then generate text or print geometry stats."""
    parser = argparse.ArgumentParser(description='K-Simplex LLM Inference')
    parser.add_argument('--checkpoint', type=str, help='Path to checkpoint file')
    parser.add_argument('--repo', type=str, default='AbstractPhil/ksimplex-llm-prototype',
                        help='HuggingFace repo ID')
    parser.add_argument('--prompt', type=str, default='ROMEO: ',
                        help='Text prompt')
    parser.add_argument('--max_tokens', type=int, default=100,
                        help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.8,
                        help='Sampling temperature')
    parser.add_argument('--top_k', type=int, default=50,
                        help='Top-k sampling')
    parser.add_argument('--top_p', type=float, default=0.9,
                        help='Nucleus sampling threshold')
    parser.add_argument('--analyze', action='store_true',
                        help='Analyze geometric properties instead of generating')
    args = parser.parse_args()

    print("Loading model...")
    # An explicit local --checkpoint takes precedence over the default --repo.
    model, tokenizer = load_model(
        checkpoint_path=args.checkpoint,
        repo_id=None if args.checkpoint else args.repo,
    )
    print(f"Model loaded on {next(model.parameters()).device}")

    if args.analyze:
        print(f"\nAnalyzing: {args.prompt}")
        for level, level_stats in analyze_geometry(model, tokenizer, args.prompt).items():
            print(f"\n{level}:")
            for name, value in level_stats.items():
                print(f"  {name}: {value:.6f}")
    else:
        print(f"\nGenerating from: {args.prompt}")
        text = generate(
            model, tokenizer, args.prompt,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print("\n" + "=" * 60)
        print(text)
        print("=" * 60)


if __name__ == '__main__':
    main()