"""
CogNet Optimized Inference Engine
==================================
Inference for the optimized CogNet model.
Supports: generate, analyze, benchmark

Usage:
    python infer_optimized.py generate --prompt "The future of AI is" --max-tokens 100
    python infer_optimized.py analyze --prompt "CogNet is"
    python infer_optimized.py benchmark
"""
|
|
| import argparse |
| import json |
| import math |
| import os |
| import sys |
| import time |
| from typing import Dict, List, Optional |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from cognet_1b_optimized import CogNet1BOptimized, create_cognet_1b_optimized |
|
|
|
|
| |
|
|
# Checkpoints are looked up in a "checkpoints" directory next to this script.
CKPT_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'checkpoints',
)
# The BPE tokenizer file is expected inside the checkpoint directory.
TOKENIZER_PATH = os.path.join(CKPT_DIR, 'bpe_tokenizer_32000.json')


# Module-level cache filled by load_model_and_tokenizer(); 'loaded' flips to
# True once the other entries hold usable objects.
_model_cache = {
    'model': None,
    'tokenizer': None,
    'device': None,
    'loaded': False,
}
|
|
|
|
def load_model_and_tokenizer(checkpoint_path: Optional[str] = None):
    """Load the model and tokenizer, reusing the cached pair when possible.

    Args:
        checkpoint_path: Checkpoint file to read weights from. When ``None``,
            ``best.pt`` and then ``latest.pt`` inside ``CKPT_DIR`` are tried.

    Returns:
        A ``(model, tokenizer, device)`` tuple. The model is moved to
        ``device``, switched to eval mode and, when possible, wrapped with
        ``torch.compile``.

    The cached objects are returned when nothing specific is requested
    (``checkpoint_path is None``) or when the requested path is the one the
    cache was built from. Asking for a different checkpoint triggers a reload
    instead of silently handing back the previously loaded model.
    """
    requested_matches_cache = (
        checkpoint_path is None
        or checkpoint_path == _model_cache.get('checkpoint_path')
    )
    if _model_cache['loaded'] and requested_matches_cache:
        return _model_cache['model'], _model_cache['tokenizer'], _model_cache['device']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Prefer the BPE tokenizer; it needs both the file and the optional
    # `tokenizers` package.
    tokenizer = None
    if os.path.exists(TOKENIZER_PATH):
        try:
            from tokenizers import Tokenizer
        except ImportError:
            # Say why we are falling back instead of skipping silently.
            print("`tokenizers` package not installed; cannot load BPE tokenizer")
        else:
            tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
            print(f"Loaded BPE tokenizer (vocab={tokenizer.get_vocab_size()})")

    if tokenizer is None:
        print("Using fallback character tokenizer")
        tokenizer = _SimpleCharTokenizer()

    vocab_size = (
        tokenizer.get_vocab_size() if hasattr(tokenizer, 'get_vocab_size')
        else tokenizer.vocab_size
    )

    model = create_cognet_1b_optimized(
        vocab_size=vocab_size,
        max_seq_len=4096,
        use_gradient_checkpointing=False,
    )

    # Resolve the default checkpoint: best.pt first, then latest.pt.
    if checkpoint_path is None:
        checkpoint_path = os.path.join(CKPT_DIR, 'best.pt')
        if not os.path.exists(checkpoint_path):
            checkpoint_path = os.path.join(CKPT_DIR, 'latest.pt')

    if checkpoint_path and os.path.exists(checkpoint_path):
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoint files from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(ckpt['model_state_dict'])
        step = ckpt.get('step', ckpt.get('metrics', {}).get('step', '?'))
        print(f"Loaded model from {checkpoint_path} (step={step})")
    else:
        print("WARNING: No trained weights found. Using random initialization.")

    model = model.to(device)
    model.eval()

    # Compilation is a best-effort speed-up; keep the eager model if it is
    # unavailable, but report the reason rather than hiding it.
    try:
        model = torch.compile(model, mode="reduce-overhead")
        print("Model compiled for inference")
    except Exception as exc:
        print(f"torch.compile unavailable, using eager model: {exc}")

    _model_cache['model'] = model
    _model_cache['tokenizer'] = tokenizer
    _model_cache['device'] = device
    # Remember which checkpoint the cache was built from so a later request
    # for a different one is detected.
    _model_cache['checkpoint_path'] = checkpoint_path
    _model_cache['loaded'] = True

    return model, tokenizer, device
|
|
|
|
class _SimpleCharTokenizer:
    """Character-level stand-in used when no BPE tokenizer is available.

    Ids ``0 .. min(vocab_size, 256) - 1`` map to the character with the same
    code point. Unknown characters encode to id 0 and unknown ids decode to a
    single space.
    """

    def __init__(self, vocab_size=256):
        self.vocab_size = vocab_size
        self._id_to_char = {}
        self._char_to_id = {}
        # Only the first 256 code points get a mapping, even for larger vocabularies.
        for code in range(min(vocab_size, 256)):
            symbol = chr(code)
            self._id_to_char[code] = symbol
            self._char_to_id[symbol] = code

    def encode(self, text):
        """Return one id per character of ``text`` (0 for unmapped characters)."""
        lookup = self._char_to_id.get
        return [lookup(symbol, 0) for symbol in text]

    def decode(self, ids):
        """Return the string for ``ids`` (a space for each unmapped id)."""
        lookup = self._id_to_char.get
        pieces = [lookup(token_id, ' ') for token_id in ids]
        return ''.join(pieces)

    def get_vocab_size(self):
        """Return the vocabulary size given at construction."""
        return self.vocab_size
|
|
|
|
| |
|
|
def handle_generate(prompt: str, max_tokens: int = 100,
                    temperature: float = 0.8, top_k: int = 40) -> Dict:
    """Generate text from a prompt and report timing.

    Args:
        prompt: Text to encode and feed to ``model.generate``.
        max_tokens: Passed to ``model.generate`` as ``max_new_tokens``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.

    Returns:
        A dict with keys ``action``, ``prompt``, ``generated_text``,
        ``num_tokens``, ``time_seconds`` and ``tokens_per_second``.
    """
    model, tokenizer, device = load_model_and_tokenizer()

    ids = tokenizer.encode(prompt)
    if not isinstance(ids, list):
        # Some tokenizers return an encoding object exposing `.ids`.
        ids = ids.ids if hasattr(ids, 'ids') else list(ids)
    if not ids:
        # An empty prompt still needs one token to start from.
        ids = [0]

    input_ids = torch.tensor([ids], dtype=torch.long, device=device)

    # CUDA kernels are launched asynchronously, so the device is synchronized
    # on both sides of the timed region; otherwise the timer can stop while
    # GPU work is still queued. perf_counter is a monotonic interval clock.
    on_cuda = device.type == 'cuda'
    if on_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        output_ids = model.generate(
            input_ids, max_new_tokens=max_tokens,
            temperature=temperature, top_k=top_k,
        )
    if on_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

    gen_ids = output_ids[0].tolist()
    gen_text = tokenizer.decode(gen_ids) if hasattr(tokenizer, 'decode') else str(gen_ids[:100])

    # NOTE(review): counts below use everything returned by model.generate;
    # if that includes the prompt tokens, they are counted too - confirm.
    return {
        'action': 'generate',
        'prompt': prompt,
        'generated_text': gen_text,
        'num_tokens': len(gen_ids),
        'time_seconds': elapsed,
        # Clamp the divisor so a near-zero elapsed time cannot divide by zero.
        'tokens_per_second': len(gen_ids) / max(elapsed, 0.001),
    }
|
|
|
|
def handle_analyze(prompt: str) -> Dict:
    """Report next-token entropy and the top predictions for a prompt.

    Args:
        prompt: Text to encode and run through the model.

    Returns:
        A dict with keys ``action``, ``prompt``, ``entropy`` (natural-log
        entropy of the distribution at the last position) and
        ``top_predictions`` (up to 10 entries with ``token_id``, ``char`` and
        ``probability``, most likely first).
    """
    model, tokenizer, device = load_model_and_tokenizer()

    ids = tokenizer.encode(prompt)
    if not isinstance(ids, list):
        # Some tokenizers return an encoding object exposing `.ids`.
        ids = ids.ids if hasattr(ids, 'ids') else list(ids)
    if not ids:
        # An empty prompt would give a zero-length sequence with no "last
        # position" to analyze; use the same placeholder as handle_generate.
        ids = [0]

    input_ids = torch.tensor([ids], dtype=torch.long, device=device)

    with torch.no_grad():
        result = model(input_ids, return_stats=True)
        logits = result['logits']

    # Distribution over the vocabulary at the final position of the sequence.
    last_logits = logits[0, -1, :]
    probs = F.softmax(last_logits, dim=-1)
    # The small epsilon keeps log() finite where a probability is exactly 0.
    entropy = -(probs * (probs + 1e-10).log()).sum().item()

    topk_vals, topk_ids = torch.topk(probs, min(10, probs.size(0)))
    top_predictions = [
        {
            'token_id': tid,
            'char': tokenizer.decode([tid]) if hasattr(tokenizer, 'decode') else f'token_{tid}',
            'probability': prob,
        }
        for prob, tid in zip(topk_vals.tolist(), topk_ids.tolist())
    ]

    return {
        'action': 'analyze',
        'prompt': prompt,
        'entropy': entropy,
        'top_predictions': top_predictions,
    }
|
|
|
|
def handle_benchmark() -> Dict:
    """Measure forward-pass throughput at several sequence lengths.

    For each length in 128, 256, 512, 1024 and 2048, a random batch of one
    sequence is run 20 times and the mean latency and tokens/second are
    recorded. A CUDA out-of-memory error is recorded as ``{'error': 'OOM'}``
    for that length.

    Returns:
        A dict with keys ``action``, ``parameters`` (total parameter count),
        ``device`` and ``results`` (one ``seq_<len>`` entry per length).
    """
    model, tokenizer, device = load_model_and_tokenizer()
    vocab_size = (
        tokenizer.get_vocab_size() if hasattr(tokenizer, 'get_vocab_size')
        else tokenizer.vocab_size
    )

    params = sum(p.numel() for p in model.parameters())

    on_cuda = device.type == 'cuda'

    def _sync() -> None:
        # CUDA work is asynchronous; wait for it so the timer sees it.
        if on_cuda:
            torch.cuda.synchronize()

    # Global warm-up so one-time start-up cost is paid before any measurement.
    warmup_input = torch.randint(0, vocab_size, (1, 128), device=device)
    with torch.no_grad():
        for _ in range(5):
            model(warmup_input)

    n_runs = 20
    warmup_runs = 2
    results = {}
    for seq_len in [128, 256, 512, 1024, 2048]:
        try:
            input_ids = torch.randint(0, vocab_size, (1, seq_len), device=device)

            with torch.no_grad():
                # Warm up at this exact shape: a compiled model may do
                # one-time work the first time it sees a new input shape,
                # which must not be charged to the timed runs.
                for _ in range(warmup_runs):
                    model(input_ids)

                _sync()
                t0 = time.perf_counter()
                for _ in range(n_runs):
                    model(input_ids)
                _sync()
                # Clamp so an extremely fast run cannot divide by zero.
                elapsed = max(time.perf_counter() - t0, 1e-9)

            tokens_per_sec = (seq_len * n_runs) / elapsed
            latency_ms = (elapsed / n_runs) * 1000

            results[f'seq_{seq_len}'] = {
                'tokens_per_second': tokens_per_sec,
                'latency_ms': latency_ms,
            }
            print(f" seq_len={seq_len:>5d}: {tokens_per_sec:>10,.0f} tokens/s, {latency_ms:.1f}ms latency")
        except torch.cuda.OutOfMemoryError:
            results[f'seq_{seq_len}'] = {'error': 'OOM'}
            print(f" seq_len={seq_len:>5d}: OOM")
            # Hand cached blocks back to the driver before the next length.
            torch.cuda.empty_cache()

    return {
        'action': 'benchmark',
        'parameters': params,
        'device': str(device),
        'results': results,
    }
|
|
|
|
| |
|
|
def main():
    """Parse command-line arguments, run the chosen action and print JSON.

    Flags: ``--prompt``, ``--max-tokens``, ``--temperature``, ``--top-k`` and
    ``--checkpoint``. The result dict of the selected handler is written to
    stdout as indented JSON.
    """
    parser = argparse.ArgumentParser(description='CogNet Optimized Inference')
    parser.add_argument('action', choices=['generate', 'analyze', 'benchmark'])
    parser.add_argument('--prompt', type=str, default='The ')
    parser.add_argument('--max-tokens', type=int, default=100)
    parser.add_argument('--temperature', type=float, default=0.8)
    parser.add_argument('--top-k', type=int, default=40)
    parser.add_argument('--checkpoint', type=str, default=None)
    args = parser.parse_args()

    if args.checkpoint:
        # The handlers call load_model_and_tokenizer() without arguments, so
        # the requested checkpoint has to be loaded into the cache up front;
        # otherwise --checkpoint would have no effect.
        _model_cache['loaded'] = False
        load_model_and_tokenizer(args.checkpoint)

    if args.action == 'generate':
        result = handle_generate(args.prompt, args.max_tokens, args.temperature, args.top_k)
    elif args.action == 'analyze':
        result = handle_analyze(args.prompt)
    elif args.action == 'benchmark':
        result = handle_benchmark()

    # default=str stringifies any value json cannot serialize natively.
    print(json.dumps(result, indent=2, ensure_ascii=False, default=str))


if __name__ == '__main__':
    main()
|
|