"""
Model inference with logits extraction.

Uses GPT-2 (small/medium) — fully open, no API key required.
Runs inside HuggingFace Space CPU environment.
"""

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

from selection_entropy import compute_token_entropies, TokenEntropy

# Singleton model cache: model name -> (model, tokenizer).
_model_cache: Dict[str, Tuple[GPT2LMHeadModel, GPT2TokenizerFast]] = {}

# Decoded vocabulary cache: model name -> decoded string for every token id.
# Decoding the whole vocabulary is expensive, so it is done once per model
# instead of once per generation request.
_vocab_cache: Dict[str, List[str]] = {}


def load_model(
    model_name: str = "gpt2",
) -> Tuple[GPT2LMHeadModel, GPT2TokenizerFast]:
    """Load and cache model + tokenizer.

    Parameters
    ----------
    model_name : name passed to ``from_pretrained`` for both the tokenizer
        and the model.

    Returns
    -------
    (model, tokenizer) : the model is switched to eval mode before caching;
        repeated calls with the same name return the cached pair.
    """
    if model_name not in _model_cache:
        print(f"Loading {model_name}...")
        tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        model.eval()
        _model_cache[model_name] = (model, tokenizer)
    return _model_cache[model_name]


def _get_vocab_tokens(
    model_name: str, tokenizer: GPT2TokenizerFast
) -> List[str]:
    """Return the decoded string of every token id, cached per model name."""
    if model_name not in _vocab_cache:
        _vocab_cache[model_name] = [
            tokenizer.decode([i]) for i in range(tokenizer.vocab_size)
        ]
    return _vocab_cache[model_name]


def generate_with_logits(
    prompt: str,
    model_name: str = "gpt2",
    max_new_tokens: int = 40,
    temperature: float = 1.0,
    alpha: float = 0.5,
) -> Tuple[str, List[TokenEntropy]]:
    """
    Generate text and return per-token entropy analysis.

    Tokens are sampled one at a time from the temperature-scaled softmax
    distribution; generation stops after ``max_new_tokens`` tokens, when the
    EOS token is sampled, or when the model's positional window is full.

    Parameters
    ----------
    prompt : text to condition on. An empty prompt is replaced by the
        tokenizer's BOS (or EOS) token so the model has something to attend to.
    model_name : model to load via ``load_model``.
    max_new_tokens : upper bound on the number of sampled tokens.
    temperature : divisor applied to the logits before softmax; clamped to a
        minimum of 1e-6 to avoid division by zero.
    alpha : forwarded unchanged to ``compute_token_entropies``.

    Returns
    -------
    generated_text : decoded generated tokens only (the prompt is not
        included), with special tokens skipped
    results : the value returned by ``compute_token_entropies`` for the
        generated tokens

    NOTE(review): an earlier docstring/annotation advertised a 3-tuple
    ``(text, shannon_results, selection_results)`` but the function has
    always returned this 2-tuple. ``results`` is presumably a list of
    ``TokenEntropy`` (one per generated token) — confirm against
    ``selection_entropy.compute_token_entropies``.

    Raises
    ------
    ValueError : if the prompt is empty and the tokenizer defines neither a
        BOS nor an EOS token.
    """
    model, tokenizer = load_model(model_name)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # An empty prompt encodes to a zero-length sequence, which the model
    # cannot run on; seed generation with a single start token instead.
    if input_ids.numel() == 0:
        start_id = tokenizer.bos_token_id
        if start_id is None:
            start_id = tokenizer.eos_token_id
        if start_id is None:
            raise ValueError(
                "Cannot generate from an empty prompt: tokenizer defines "
                "no BOS or EOS token."
            )
        input_ids = torch.tensor([[start_id]], dtype=torch.long)

    # The model only has position embeddings for ``n_positions`` tokens.
    # Keep the most recent part of an over-long prompt so the first forward
    # pass cannot index past the end of the embedding table.
    max_positions = getattr(model.config, "n_positions", None)
    if max_positions is not None and input_ids.shape[1] > max_positions:
        input_ids = input_ids[:, -max_positions:]

    # Fresh shallow copy of the cached vocabulary so the callee can never
    # corrupt the shared cache.
    vocab_tokens = list(_get_vocab_tokens(model_name, tokenizer))

    safe_temperature = max(temperature, 1e-6)

    generated_ids = []
    all_logits = []
    positions_used = 0

    with torch.no_grad():
        past_key_values = None
        current_ids = input_ids

        for _ in range(max_new_tokens):
            # Stop before feeding tokens the positional window cannot hold.
            step_len = current_ids.shape[1]
            if (
                max_positions is not None
                and positions_used + step_len > max_positions
            ):
                break

            outputs = model(
                current_ids,
                past_key_values=past_key_values,
                use_cache=True,
            )
            positions_used += step_len

            logits = outputs.logits[:, -1, :]  # (1, vocab_size)
            past_key_values = outputs.past_key_values

            # Apply temperature
            scaled_logits = logits / safe_temperature
            probs = torch.softmax(scaled_logits, dim=-1)

            # Sample next token
            next_token = torch.multinomial(probs, num_samples=1)
            token_id = next_token.item()

            generated_ids.append(token_id)
            # The temperature-scaled logits (not the raw ones) are what the
            # entropy analysis receives.
            all_logits.append(scaled_logits[0].detach().cpu().numpy())

            # With the KV cache in place only the new token is fed next.
            current_ids = next_token

            # Stop at EOS
            if token_id == tokenizer.eos_token_id:
                break

    # Decode tokens individually for alignment
    tokens = [tokenizer.decode([tid]) for tid in generated_ids]
    logits_array = np.array(all_logits)  # (seq_len, vocab_size)

    # Compute both entropy metrics
    results = compute_token_entropies(
        logits_sequence=logits_array,
        tokens=tokens,
        token_ids=generated_ids,
        vocab_tokens=vocab_tokens,
        alpha=alpha,
        top_k_display=5,
    )

    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return generated_text, results


def get_available_models() -> List[str]:
    """Return the model names offered for selection."""
    return ["gpt2", "gpt2-medium"]