CogNet-1B / hf_scripts /infer_optimized.py
thefinalboss's picture
Upload hf_scripts/infer_optimized.py with huggingface_hub
ce094de verified
Raw
History Blame Contribute Delete
9.3 kB
"""
CogNet Optimized Inference Engine
==================================
Inference for the optimized CogNet model.
Supports: generate, analyze, benchmark
Usage:
    python infer_optimized.py generate --prompt "The future of AI is" --max-tokens 100
    python infer_optimized.py analyze --prompt "CogNet is"
    python infer_optimized.py benchmark
"""
import argparse
import json
import math
import os
import sys
import time
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from cognet_1b_optimized import CogNet1BOptimized, create_cognet_1b_optimized
# ─── Model & Tokenizer Loading ───────────────────────────────────────────────
# Checkpoints and the tokenizer file are resolved relative to this script,
# not the current working directory.
CKPT_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'checkpoints',
)
TOKENIZER_PATH = os.path.join(CKPT_DIR, 'bpe_tokenizer_32000.json')

# Process-wide cache populated by load_model_and_tokenizer().
_model_cache = {
    'model': None,
    'tokenizer': None,
    'device': None,
    'loaded': False,
}
def load_model_and_tokenizer(checkpoint_path: Optional[str] = None):
    """Load the model and tokenizer, caching the result in ``_model_cache``.

    Args:
        checkpoint_path: Checkpoint file to load weights from. When ``None``,
            ``best.pt`` and then ``latest.pt`` inside ``CKPT_DIR`` are tried.
            A call with ``None`` reuses whatever is already cached; a call
            with an explicit path reuses the cache only if that same path
            was loaded.

    Returns:
        A ``(model, tokenizer, device)`` tuple. The model is in eval mode on
        ``device``; if no checkpoint file exists it keeps its initial weights.
    """
    cache_hit = _model_cache['loaded'] and (
        checkpoint_path is None
        or checkpoint_path == _model_cache.get('checkpoint_path')
    )
    if cache_hit:
        return _model_cache['model'], _model_cache['tokenizer'], _model_cache['device']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load tokenizer
    tokenizer = None
    if os.path.exists(TOKENIZER_PATH):
        try:
            from tokenizers import Tokenizer
        except ImportError:
            # Optional dependency: report it instead of failing silently.
            print("`tokenizers` package not installed; cannot load BPE tokenizer")
        else:
            tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
            print(f"Loaded BPE tokenizer (vocab={tokenizer.get_vocab_size()})")
    if tokenizer is None:
        # Fallback: simple char tokenizer
        print("Using fallback character tokenizer")
        tokenizer = _SimpleCharTokenizer()
    vocab_size = (
        tokenizer.get_vocab_size()
        if hasattr(tokenizer, 'get_vocab_size')
        else tokenizer.vocab_size
    )

    # Create model
    model = create_cognet_1b_optimized(
        vocab_size=vocab_size,
        max_seq_len=4096,
        use_gradient_checkpointing=False,
    )

    # Resolve the default checkpoint: prefer best.pt, then latest.pt.
    if checkpoint_path is None:
        checkpoint_path = os.path.join(CKPT_DIR, 'best.pt')
        if not os.path.exists(checkpoint_path):
            checkpoint_path = os.path.join(CKPT_DIR, 'latest.pt')

    if checkpoint_path and os.path.exists(checkpoint_path):
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoints that come from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(ckpt['model_state_dict'])
        step = ckpt.get('step', ckpt.get('metrics', {}).get('step', '?'))
        print(f"Loaded model from {checkpoint_path} (step={step})")
    else:
        print("WARNING: No trained weights found. Using random initialization.")

    model = model.to(device)
    model.eval()

    # Compilation is best-effort; on failure keep running the eager model.
    try:
        model = torch.compile(model, mode="reduce-overhead")
        print("Model compiled for inference")
    except Exception as exc:
        print(f"torch.compile unavailable, running eagerly: {exc}")

    _model_cache['model'] = model
    _model_cache['tokenizer'] = tokenizer
    _model_cache['device'] = device
    # Remember which checkpoint produced the cached model so an explicit
    # request for a different one triggers a reload.
    _model_cache['checkpoint_path'] = checkpoint_path
    _model_cache['loaded'] = True
    return model, tokenizer, device
class _SimpleCharTokenizer:
    """Character-level stand-in used when the BPE tokenizer is unavailable.

    Each of the first ``min(vocab_size, 256)`` code points maps to the token
    id equal to its ordinal.
    """

    def __init__(self, vocab_size=256):
        self.vocab_size = vocab_size
        table_size = min(vocab_size, 256)
        self._id_to_char = dict(enumerate(map(chr, range(table_size))))
        self._char_to_id = {ch: idx for idx, ch in self._id_to_char.items()}

    def encode(self, text):
        """Return one id per character; unknown characters become id 0."""
        lookup = self._char_to_id.get
        return [lookup(ch, 0) for ch in text]

    def decode(self, ids):
        """Return the text for ``ids``; unknown ids become a space."""
        lookup = self._id_to_char.get
        return ''.join([lookup(token_id, ' ') for token_id in ids])

    def get_vocab_size(self):
        """Return the configured vocabulary size."""
        return self.vocab_size
# ─── Actions ──────────────────────────────────────────────────────────────────
def handle_generate(prompt: str, max_tokens: int = 100,
                    temperature: float = 0.8, top_k: int = 40) -> Dict:
    """Generate text from a prompt and report timing.

    Args:
        prompt: Text to condition on. An empty prompt is replaced by the
            single token id 0 so the model always receives input.
        max_tokens: Passed to ``model.generate`` as ``max_new_tokens``.
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k cutoff passed to ``model.generate``.

    Returns:
        Dict with keys ``action``, ``prompt``, ``generated_text``,
        ``num_tokens``, ``time_seconds`` and ``tokens_per_second``.
    """
    model, tokenizer, device = load_model_and_tokenizer()

    ids = tokenizer.encode(prompt)
    if not isinstance(ids, list):
        # Some tokenizers return an encoding object exposing `.ids`.
        ids = ids.ids if hasattr(ids, 'ids') else list(ids)
    if not ids:
        ids = [0]
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)

    # CUDA kernels run asynchronously, so synchronise on both sides of the
    # timed region; otherwise the clock can stop before the GPU has finished.
    on_cuda = device.type == 'cuda'
    if on_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        output_ids = model.generate(
            input_ids, max_new_tokens=max_tokens,
            temperature=temperature, top_k=top_k,
        )
    if on_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

    gen_ids = output_ids[0].tolist()
    gen_text = tokenizer.decode(gen_ids) if hasattr(tokenizer, 'decode') else str(gen_ids[:100])
    # NOTE(review): num_tokens counts everything model.generate returned; if
    # that includes the prompt tokens, so do these figures - confirm.
    return {
        'action': 'generate',
        'prompt': prompt,
        'generated_text': gen_text,
        'num_tokens': len(gen_ids),
        'time_seconds': elapsed,
        'tokens_per_second': len(gen_ids) / max(elapsed, 0.001),
    }
def handle_analyze(prompt: str) -> Dict:
    """Report next-token entropy and the top predictions after ``prompt``.

    Args:
        prompt: Text to run through the model. An empty prompt is replaced
            by the single token id 0 (as in ``handle_generate``) so there is
            always a last position to read logits from.

    Returns:
        Dict with keys ``action``, ``prompt``, ``entropy`` (natural-log
        entropy of the last position's softmax distribution) and
        ``top_predictions`` (up to 10 dicts with ``token_id``, ``char`` and
        ``probability``).
    """
    model, tokenizer, device = load_model_and_tokenizer()

    ids = tokenizer.encode(prompt)
    if not isinstance(ids, list):
        # Some tokenizers return an encoding object exposing `.ids`.
        ids = ids.ids if hasattr(ids, 'ids') else list(ids)
    if not ids:
        # A zero-length sequence has no last position to index below.
        ids = [0]
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)

    with torch.no_grad():
        result = model(input_ids, return_stats=True)
    logits = result['logits']

    # Last token predictions
    last_logits = logits[0, -1, :]
    probs = F.softmax(last_logits, dim=-1)
    # The epsilon keeps log() finite for zero-probability entries.
    entropy = -(probs * (probs + 1e-10).log()).sum().item()

    # Top 10 predictions (fewer if the vocabulary is smaller)
    topk_vals, topk_ids = torch.topk(probs, min(10, probs.size(0)))
    can_decode = hasattr(tokenizer, 'decode')
    top_predictions = [
        {
            'token_id': tid,
            'char': tokenizer.decode([tid]) if can_decode else f'token_{tid}',
            'probability': prob,
        }
        for prob, tid in zip(topk_vals.tolist(), topk_ids.tolist())
    ]
    return {
        'action': 'analyze',
        'prompt': prompt,
        'entropy': entropy,
        'top_predictions': top_predictions,
    }
def handle_benchmark() -> Dict:
    """Benchmark forward-pass throughput at several sequence lengths.

    For each length, random token ids of shape ``(1, seq_len)`` are run
    through the model: first a few untimed warm-up passes, then 20 timed
    passes.

    Returns:
        Dict with keys ``action``, ``parameters`` (total parameter count),
        ``device`` and ``results``. ``results`` maps ``seq_<len>`` to either
        ``{'tokens_per_second', 'latency_ms'}`` or ``{'error': 'OOM'}``.
    """
    model, tokenizer, device = load_model_and_tokenizer()
    vocab_size = (
        tokenizer.get_vocab_size()
        if hasattr(tokenizer, 'get_vocab_size')
        else tokenizer.vocab_size
    )

    # Param count
    params = sum(p.numel() for p in model.parameters())

    use_cuda = torch.cuda.is_available()

    def _sync():
        # CUDA kernels are asynchronous; wait for them before reading the clock.
        if use_cuda:
            torch.cuda.synchronize()

    n_warmup = 3
    n_runs = 20
    results = {}
    for seq_len in [128, 256, 512, 1024, 2048]:
        try:
            input_ids = torch.randint(0, vocab_size, (1, seq_len), device=device)
            with torch.no_grad():
                # Warm up at this exact shape so one-time costs (e.g.
                # torch.compile re-specialisation) stay out of the timing.
                for _ in range(n_warmup):
                    model(input_ids)
                # Timed runs
                _sync()
                t0 = time.perf_counter()
                for _ in range(n_runs):
                    model(input_ids)
                _sync()
                elapsed = time.perf_counter() - t0
            tokens_per_sec = (seq_len * n_runs) / elapsed
            latency_ms = (elapsed / n_runs) * 1000
            results[f'seq_{seq_len}'] = {
                'tokens_per_second': tokens_per_sec,
                'latency_ms': latency_ms,
            }
            print(f" seq_len={seq_len:>5d}: {tokens_per_sec:>10,.0f} tokens/s, {latency_ms:.1f}ms latency")
        except torch.cuda.OutOfMemoryError:
            results[f'seq_{seq_len}'] = {'error': 'OOM'}
            print(f" seq_len={seq_len:>5d}: OOM")
    return {
        'action': 'benchmark',
        'parameters': params,
        'device': str(device),
        'results': results,
    }
# ─── CLI ──────────────────────────────────────────────────────────────────────
def main():
    """Parse command-line arguments, run the chosen action, print JSON.

    ``--checkpoint`` selects the weights file: the model is loaded from it
    up front so the action handlers, which call
    ``load_model_and_tokenizer()`` without arguments, reuse that model.
    """
    parser = argparse.ArgumentParser(description='CogNet Optimized Inference')
    parser.add_argument('action', choices=['generate', 'analyze', 'benchmark'])
    parser.add_argument('--prompt', type=str, default='The ')
    parser.add_argument('--max-tokens', type=int, default=100)
    parser.add_argument('--temperature', type=float, default=0.8)
    parser.add_argument('--top-k', type=int, default=40)
    parser.add_argument('--checkpoint', type=str, default=None)
    args = parser.parse_args()

    if args.checkpoint:
        # Drop any cached model, then prime the cache from the requested
        # checkpoint so the handlers below pick it up.
        _model_cache['loaded'] = False
        load_model_and_tokenizer(args.checkpoint)

    if args.action == 'generate':
        result = handle_generate(args.prompt, args.max_tokens, args.temperature, args.top_k)
    elif args.action == 'analyze':
        result = handle_analyze(args.prompt)
    elif args.action == 'benchmark':
        result = handle_benchmark()
    # default=str stringifies any value json cannot serialise natively.
    print(json.dumps(result, indent=2, ensure_ascii=False, default=str))
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()