# InflectionLM / inflections_funcs.py
# Uploaded by cafierom ("Upload 2 files"), commit 360bdcd, 5.58 kB.
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
import torch.nn.functional as F
from torch import exp
from typing import Any, Tuple, List
def start_model(model_id: str = "google/gemma-4-31B-it"):
    '''
    Load a causal language model together with its matching processor.

    Args:
        model_id: Hugging Face Hub repository id (or local path) to load.

    Returns:
        A ``(model, processor)`` tuple -- the model comes first.
    '''
    # dtype="auto" / device_map="auto" defer dtype choice and device
    # placement to transformers (per its from_pretrained documentation).
    load_kwargs = {"dtype": "auto", "device_map": "auto"}
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    print(f'Model {model_id} has been installed.')
    return model, processor
def make_beams(
    model: AutoModelForCausalLM,
    processor: AutoProcessor,
    initial_prompt: str,
    temperature: float = 1.0,
    *,
    num_return_sequences: int = 3,
    max_new_tokens: int = 1024,
    top_p: float = 0.9,
    top_k: int = 50,
) -> Tuple[Any, List[str]]:
    '''
    Sample several independent responses to one user prompt.

    The prompt is wrapped in a two-message chat (fixed system prompt plus the
    user turn), rendered with the processor's chat template (generation
    prompt added, ``enable_thinking=False``) and passed to ``model.generate``
    with pure sampling (``num_beams=1``, ``do_sample=True``). Despite the
    function name, no beam search is performed.

    Args:
        model: causal LM, e.g. as returned by ``start_model``.
        processor: the matching processor.
        initial_prompt: text of the user message.
        temperature: sampling temperature; values <= 0 are replaced by 0.1
            because sampling requires a strictly positive temperature.
        num_return_sequences: number of independent samples (default 3).
        max_new_tokens: maximum number of generated tokens per sample.
        top_p: nucleus-sampling probability mass.
        top_k: number of highest-scoring tokens kept before sampling.

    Returns:
        ``(generated_dicts, transcription)``: the raw ``generate`` output
        (requested with ``return_dict_in_generate=True`` and
        ``output_scores=True``, so it carries ``.sequences`` and per-step
        ``.scores``), and the full decoded sequences with special tokens
        removed. Because whole sequences are decoded, each string still
        contains the rendered prompt text; ``parse_beams`` strips it.
    '''
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": initial_prompt},
    ]
    # Render the chat to a plain string first; tokenization happens below.
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = processor(text=text, return_tensors="pt").to(model.device)
    generated_dicts = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=1,  # disable beam search: pure sampling
        num_return_sequences=num_return_sequences,  # independent samples
        return_dict_in_generate=True,
        output_scores=True,  # per-step scores, used by calculate_score_vectors
        temperature=temperature if temperature > 0 else 0.1,  # sampling needs T > 0
        top_p=top_p,
        top_k=top_k,
        do_sample=True,
    )
    transcription = processor.batch_decode(generated_dicts.sequences, skip_special_tokens=True)

    # Debug output: list the fields present on the generate output.
    print('Keys in model output -------------------')
    for key in generated_dicts:
        print(key)
    print('----------------------------------------')

    print('Beam scores ----------------------------')
    # sequences_scores is only filled in by beam search; with num_beams=1 it
    # is absent, so the else branch is what normally runs. getattr(..., None)
    # also covers an attribute that exists but is None, which hasattr did not.
    # Per-token probabilities for sampled outputs are computed separately
    # (see calculate_score_vectors); no aggregate score is derived here.
    sequences_scores = getattr(generated_dicts, 'sequences_scores', None)
    if sequences_scores is not None:
        for score in sequences_scores:
            print(exp(score).item())
    else:
        print('Sampling mode: sequences_scores not available.')
    print('----------------------------------------')
    return generated_dicts, transcription
def parse_beams(transcription: List[str], marker: str = 'model\nthought') -> List[str]:
    '''
    Strip the echoed prompt from each decoded beam.

    Each beam is cut at the FIRST occurrence of ``marker`` (by default the
    line ``model`` followed by the line ``thought``). Everything after it,
    with leading and trailing newlines removed, is kept. Beams that do not
    contain the marker are returned unchanged.

    Args:
        transcription: decoded beam strings, e.g. from ``make_beams``.
        marker: separator between the echoed prompt and the response.

    Returns:
        One response string per input beam, in the same order.
    '''
    beam_text = []
    for beam in transcription:
        # partition() splits only at the first marker. split()[1] silently
        # dropped everything after a second occurrence of the marker inside
        # the response.
        _, found, response = beam.partition(marker)
        beam_text.append(response.strip('\n') if found else beam)
    print('Beams have been parsed. --------------')
    return beam_text
def get_beam_tokens(generated_dicts: Any, processor: AutoProcessor) -> List[List[str]]:
    '''
    Turn the generated part of every returned sequence into token strings.

    The prompt prefix is dropped. ``generated_dicts.scores`` holds one entry
    per generation step, so the completion is the last ``len(scores)`` ids of
    each sequence. No other trimming is done; any padding after a sequence
    that stopped early is presumably included (TODO confirm against
    ``generate``'s padding behaviour).

    Returns:
        One list of token strings per sequence (from the tokenizer's
        ``convert_ids_to_tokens``), in sequence order.
    '''
    n_generated = len(generated_dicts.scores)
    return [
        processor.tokenizer.convert_ids_to_tokens(
            seq[seq.shape[0] - n_generated:].tolist()
        )
        for seq in generated_dicts.sequences
    ]
def calculate_score_vectors(model: AutoModelForCausalLM, generated_dicts: Any) -> List[List[float]]:
    '''
    Return, for each generated sequence, the probability of every token it chose.

    Probabilities come from ``generated_dicts.scores`` (one
    ``(num_sequences, vocab_size)`` tensor per generation step) rather than
    from a fresh forward pass, which avoids GPU timeouts on ZeroGPU.

    Each step is normalised with ``log_softmax`` and only the chosen token's
    entry is gathered. The full ``(gen_len, num_sequences, vocab_size)``
    probability tensor is therefore never materialised, and the results are
    copied to the host in a single transfer instead of one ``.item()`` per
    token.

    NOTE(review): with ``output_scores=True`` transformers documents
    ``scores`` as the *processed* logits (after temperature / top-k / top-p).
    These values are therefore probabilities under the sampling distribution,
    not the raw model distribution. Confirm that is what callers want.
    NOTE(review): sequences that stop early are presumably padded by
    ``generate``. Those pad positions are scored like any other token and
    are not trimmed here.

    Args:
        model: unused; kept so the signature stays backward-compatible.
        generated_dicts: ``generate`` output with a 2-D ``.sequences``
            (prompt + generated ids per row) and ``.scores``, a sequence of
            ``gen_len`` per-step score tensors.

    Returns:
        ``num_sequences`` lists, each with ``gen_len`` floats in [0, 1].
        If nothing was generated, each list is empty.
    '''
    sequences = generated_dicts.sequences
    step_scores = generated_dicts.scores
    num_sequences = sequences.shape[0]
    gen_len = len(step_scores)
    if gen_len == 0:
        # torch.stack([]) would raise; there is simply nothing to score.
        return [[] for _ in range(num_sequences)]
    # The generated tokens occupy the last gen_len columns of each row.
    input_len = sequences.shape[1] - gen_len

    chosen_probs = []
    for t, logits in enumerate(step_scores):
        token_ids = sequences[:, input_len + t].to(logits.device).unsqueeze(-1)
        # float() keeps the normalisation in full precision even if the
        # scores arrive in a reduced-precision dtype.
        log_probs = F.log_softmax(logits.float(), dim=-1)
        chosen_probs.append(log_probs.gather(-1, token_ids).squeeze(-1).exp())

    # Stack to (num_sequences, gen_len) and move to the host once.
    return torch.stack(chosen_probs, dim=1).tolist()