selection-entropy-llm / model_inference.py
mioulin's picture
Rename model-inference.py to model_inference.py
521843a verified
"""
Model inference with logits extraction.
Uses GPT-2 (small/medium) — fully open, no API key required.
Runs inside HuggingFace Space CPU environment.
"""
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from typing import Tuple, List, Optional
from selection_entropy import compute_token_entropies, TokenEntropy
# Process-wide cache mapping model name -> (model, tokenizer); filled lazily by load_model().
_model_cache: dict = dict()
def load_model(model_name: str = "gpt2") -> Tuple:
    """
    Return the ``(model, tokenizer)`` pair for *model_name*, loading it on first use.

    The first request for a name downloads/loads the tokenizer and the
    GPT-2 LM-head model, switches the model to eval mode and stores the
    pair in ``_model_cache``; later requests return that same pair.
    """
    if model_name in _model_cache:
        return _model_cache[model_name]

    print(f"Loading {model_name}...")
    tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.eval()  # inference only (disables dropout)
    pair = (model, tokenizer)
    _model_cache[model_name] = pair
    return pair
# Decoded string for every token id, per model name. Built once because
# decoding ~50k ids one by one is the slowest part of a request.
_vocab_tokens_cache: dict = {}


def _get_vocab_tokens(model_name: str, tokenizer) -> Tuple[str, ...]:
    """Return ``tokenizer.decode([i])`` for every id in the base vocabulary, cached per model."""
    if model_name not in _vocab_tokens_cache:
        _vocab_tokens_cache[model_name] = tuple(
            tokenizer.decode([i]) for i in range(tokenizer.vocab_size)
        )
    return _vocab_tokens_cache[model_name]


def _sample_with_logits(
    model,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float,
    eos_token_id: Optional[int],
    generator: Optional[torch.Generator] = None,
) -> Tuple[List[int], List[np.ndarray]]:
    """Sample up to *max_new_tokens* ids one at a time, recording each step's scaled logits."""
    generated_ids: List[int] = []
    step_logits: List[np.ndarray] = []
    past_key_values = None
    current_ids = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(
                current_ids,
                past_key_values=past_key_values,
                use_cache=True,
            )
            # Only the newest position is needed; earlier ones live in the KV cache.
            past_key_values = outputs.past_key_values
            # Clamp so temperature <= 0 behaves as (near-)greedy instead of dividing by zero.
            scaled_logits = outputs.logits[:, -1, :] / max(temperature, 1e-6)
            probs = torch.softmax(scaled_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1, generator=generator)
            token_id = next_token.item()
            generated_ids.append(token_id)
            step_logits.append(scaled_logits[0].cpu().numpy())
            # With the cache, the next step feeds just the token just sampled.
            current_ids = next_token
            if token_id == eos_token_id:
                break
    return generated_ids, step_logits


def generate_with_logits(
    prompt: str,
    model_name: str = "gpt2",
    max_new_tokens: int = 40,
    temperature: float = 1.0,
    alpha: float = 0.5,
    *,
    seed: Optional[int] = None,
) -> Tuple[str, List[TokenEntropy]]:
    """
    Sample a continuation of *prompt* and run the entropy analysis on it.

    Tokens are sampled one at a time from ``softmax(logits / temperature)``,
    reusing the KV cache between steps. The temperature-scaled logits of
    every step are passed to ``compute_token_entropies`` together with the
    individually decoded tokens and the decoded vocabulary.

    Parameters
    ----------
    prompt : str
        Conditioning text. An empty prompt is replaced by the tokenizer's
        BOS token. A prompt that would overflow the model's context window
        (``model.config.n_positions``) together with the generation budget
        is truncated from the left.
    model_name : str
        Name passed to ``load_model``.
    max_new_tokens : int
        Upper bound on sampled tokens. Sampling also stops after EOS. The
        bound is clamped so prompt + generation fit in the context window.
    temperature : float
        Softmax temperature. Values <= 0 are clamped to 1e-6, which makes
        sampling effectively greedy.
    alpha : float
        Forwarded unchanged to ``compute_token_entropies``.
    seed : int, optional
        If given, sampling uses a dedicated ``torch.Generator`` seeded with
        this value, so results are reproducible without touching the
        global RNG.

    Returns
    -------
    generated_text : str
        Decoded continuation with special tokens (e.g. EOS) removed.
    results
        Whatever ``compute_token_entropies`` returns for the generated
        tokens. NOTE(review): an earlier docstring promised separate
        Shannon and Selection-Entropy lists, but this function has always
        returned a 2-tuple. Presumably each ``TokenEntropy`` carries both
        metrics; confirm against ``selection_entropy``.
    """
    model, tokenizer = load_model(model_name)

    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    if input_ids.shape[-1] == 0:
        # The model cannot run on a zero-length sequence; start from BOS instead.
        start_id = tokenizer.bos_token_id
        if start_id is None:
            start_id = tokenizer.eos_token_id
        input_ids = torch.tensor([[start_id]], dtype=torch.long)

    max_positions = getattr(model.config, "n_positions", None)
    if max_positions is not None:
        # GPT-2 uses learned absolute position embeddings, so running past
        # n_positions fails in the embedding lookup. Keep the most recent
        # prompt tokens and leave room for the generated ones.
        max_new_tokens = min(max_new_tokens, max_positions - 1)
        input_ids = input_ids[:, -(max_positions - max_new_tokens):]

    generator = None
    if seed is not None:
        generator = torch.Generator(device=input_ids.device).manual_seed(seed)

    generated_ids, step_logits = _sample_with_logits(
        model,
        input_ids,
        max_new_tokens,
        temperature,
        tokenizer.eos_token_id,
        generator,
    )

    # Decode tokens individually so each entry aligns with one logits row.
    tokens = [tokenizer.decode([tid]) for tid in generated_ids]
    if step_logits:
        logits_array = np.stack(step_logits)  # (seq_len, vocab_size)
    else:
        # Keep the 2-D (0, vocab_size) shape when nothing was generated.
        logits_array = np.empty((0, model.config.vocab_size), dtype=np.float32)

    results = compute_token_entropies(
        logits_sequence=logits_array,
        tokens=tokens,
        token_ids=generated_ids,
        # Pass a fresh list so the cached vocabulary cannot be mutated downstream.
        vocab_tokens=list(_get_vocab_tokens(model_name, tokenizer)),
        alpha=alpha,
        top_k_display=5,
    )
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return generated_text, results
def get_available_models() -> List[str]:
    """Return the selectable model names as a new list on every call."""
    choices = ("gpt2", "gpt2-medium")
    return list(choices)