# llmhallucination / model_loader.py
# Author: abhignya99 — commit bad01a8 ("deploy app")
"""
Model Loader Module
Loads GPT-2 and GPT-Neo models using TransformerLens and provides text generation functionality.
"""
import torch
from transformer_lens import HookedTransformer
from typing import List, Dict, Any
import numpy as np
class GPT2ModelLoader:
    """
    Handles loading and text generation with GPT-2 and GPT-Neo models using TransformerLens.

    Supports: gpt2, gpt2-medium, gpt2-large, gpt2-xl,
    EleutherAI/gpt-neo-125M, EleutherAI/gpt-neo-1.3B, EleutherAI/gpt-neo-2.7B
    """

    def __init__(self, model_name: str = "gpt2"):
        """
        Load a pretrained model onto the GPU when available, otherwise the CPU.

        Args:
            model_name: Name of the GPT-2 or GPT-Neo model variant to load; it is
                passed straight to ``HookedTransformer.from_pretrained``.
        """
        print(f"Loading {model_name} model...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = HookedTransformer.from_pretrained(model_name, device=self.device)
        self.model_name = model_name
        print(f"Model loaded successfully on {self.device}")

    def _decode_new_tokens(self, sequence: torch.Tensor, prompt_length: int) -> str:
        """
        Decode only the tokens generated after the prompt, without the EOS marker.

        With ``stop_at_eos=True``, ``HookedTransformer.generate`` appends the
        sampled end-of-text token to its output. ``to_string`` does not skip
        special tokens, so decoding naively leaves a literal "<|endoftext|>" in
        the text. Everything from the first EOS onward is therefore dropped.

        Args:
            sequence: 1-D tensor holding the prompt tokens followed by the
                generated tokens.
            prompt_length: Number of leading prompt tokens to skip.

        Returns:
            The decoded continuation with leading whitespace stripped.
        """
        new_tokens = sequence[prompt_length:]
        tokenizer = getattr(self.model, "tokenizer", None)
        eos_id = getattr(tokenizer, "eos_token_id", None)
        if eos_id is not None:
            eos_positions = (new_tokens == eos_id).nonzero(as_tuple=True)[0]
            if eos_positions.numel() > 0:
                new_tokens = new_tokens[: int(eos_positions[0])]
        return self.model.to_string(new_tokens).lstrip()

    def generate_responses(
        self,
        prompt: str,
        num_responses: int = 5,
        max_length: int = 50,
        temperature: float = 0.8,
        top_p: float = 0.9
    ) -> List[str]:
        """
        Generate multiple stochastic responses for a given prompt.

        Each response is sampled independently from the same prompt.

        Args:
            prompt: Input text prompt
            num_responses: Number of responses to generate
            max_length: Maximum number of NEW tokens to generate per response
                (the prompt is not counted)
            temperature: Sampling temperature for diversity
            top_p: Nucleus sampling parameter

        Returns:
            List of generated continuations (prompt and end-of-text marker
            excluded), one per requested response
        """
        responses: List[str] = []
        # Tokenize once; every sample starts from the same prompt tokens.
        prompt_tokens = self.model.to_tokens(prompt)
        prompt_token_len = prompt_tokens.shape[1]
        for i in range(num_responses):
            generated_tokens = self.model.generate(
                prompt_tokens,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                stop_at_eos=True,
                return_type="tokens",
            )
            # Batch of one: decode row 0, skipping the echoed prompt.
            responses.append(self._decode_new_tokens(generated_tokens[0], prompt_token_len))
            print(f"Generated response {i+1}/{num_responses}")
        return responses

    def generate_with_cache(
        self,
        prompt: str,
        max_length: int = 50,
        temperature: float = 0.8,
        top_p: float = 0.9
    ) -> Dict[str, Any]:
        """
        Generate text, then re-run the full sequence to capture model activations.

        Args:
            prompt: Input text prompt
            max_length: Maximum number of NEW tokens to generate (the prompt is
                not counted)
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter

        Returns:
            Dictionary with:
              - "text": decoded continuation (prompt and end-of-text marker excluded)
              - "tokens": 1-D tensor of prompt tokens followed by generated tokens
              - "logits", "cache": output of ``run_with_cache`` over "tokens"
              - "prompt_length": number of prompt tokens, i.e. the index in
                "tokens" where generation starts
        """
        tokens = self.model.to_tokens(prompt)
        prompt_length = tokens.shape[1]

        with torch.no_grad():
            generated_tokens = self.model.generate(
                tokens,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                stop_at_eos=True,
                return_type="tokens",
            )

        # Batch of one: keep the single full sequence (prompt + generation).
        full_tokens = generated_tokens[0]

        # A second forward pass over the complete sequence records activations.
        with torch.no_grad():
            logits, cache = self.model.run_with_cache(full_tokens)

        generated_text = self._decode_new_tokens(full_tokens, prompt_length)

        return {
            "text": generated_text,
            "tokens": full_tokens,
            "logits": logits,
            "cache": cache,
            "prompt_length": prompt_length
        }

    def get_model(self) -> HookedTransformer:
        """Return the underlying HookedTransformer model."""
        return self.model