from typing import Any, List, Tuple

import torch
import torch.nn.functional as F
from torch import exp
from transformers import AutoModelForCausalLM, AutoProcessor


def start_model(model_id: str = "google/gemma-4-31B-it") -> Tuple[Any, Any]:
    '''
    Load a Hugging Face causal LM and its processor.

    Args:
        model_id: Hub model identifier (or local path) passed to
            ``from_pretrained``.

    Returns:
        A ``(model, processor)`` tuple -- note the model comes first.
    '''
    processor = AutoProcessor.from_pretrained(model_id)
    # Let transformers pick the checkpoint's dtype and place the weights on
    # whatever devices are available.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, dtype="auto", device_map="auto"
    )
    print(f'Model {model_id} has been installed.')
    return model, processor


def make_beams(model: AutoModelForCausalLM,
               processor: AutoProcessor,
               initial_prompt: str,
               temperature: float = 1.0,
               num_return_sequences: int = 3,
               max_new_tokens: int = 1024) -> Tuple[Any, List[str]]:
    '''
    Sample several independent responses to a prompt.

    Despite the name, this does not use beam search (``num_beams=1``): each
    returned "beam" is an independent top-k / nucleus sample.

    Args:
        model: Causal LM, e.g. as returned by ``start_model``.
        processor: The matching processor.
        initial_prompt: User message to respond to.
        temperature: Sampling temperature. Values <= 0 are replaced by 0.1
            because sampling needs a strictly positive temperature.
        num_return_sequences: Number of samples to draw (default 3).
        max_new_tokens: Per-sample generation length cap (default 1024).

    Returns:
        ``(generated_dicts, transcription)``: the raw ``generate`` output
        (with ``sequences`` and per-step ``scores``) and the decoded text of
        each sequence with special tokens removed.
    '''
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": initial_prompt},
    ]

    # Render the chat template to a string (thinking disabled), then tokenize.
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = processor(text=text, return_tensors="pt").to(model.device)

    generated_dicts = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=1,  # no beam search: pure sampling
        num_return_sequences=num_return_sequences,  # independent samples
        return_dict_in_generate=True,
        output_scores=True,  # keep per-step scores for calculate_score_vectors
        temperature=temperature if temperature > 0 else 0.1,  # sampling needs T > 0
        top_p=0.9,  # nucleus sampling
        top_k=50,  # drop the long tail of unlikely tokens
        do_sample=True,
    )
    transcription = processor.batch_decode(
        generated_dicts.sequences, skip_special_tokens=True
    )

    print('Keys in model output -------------------')
    for key in generated_dicts:
        print(key)
    print('----------------------------------------')

    print('Beam scores ----------------------------')
    # sequences_scores (exponentiated below) is only populated by beam search;
    # in sampling mode per-token probabilities come from calculate_score_vectors.
    sequences_scores = getattr(generated_dicts, 'sequences_scores', None)
    if sequences_scores is not None:
        for score in sequences_scores:
            print(exp(score).item())
    else:
        print('Sampling mode: sequences_scores not available.')
    print('----------------------------------------')

    return generated_dicts, transcription


def parse_beams(transcription: List[str]) -> List[str]:
    '''
    Keep only the response part of each decoded sequence.

    Returns the text after the first occurrence of the ``'model\\nthought'``
    marker, with surrounding newlines stripped. Sequences that do not contain
    the marker are returned unchanged.
    '''
    marker = '''model\nthought'''
    beam_text = []
    for beam in transcription:
        # Partition on the first marker only, so a response that happens to
        # contain the marker again is not truncated.
        _, found, response = beam.partition(marker)
        beam_text.append(response.strip('\n') if found else beam)
    print('Beams have been parsed. --------------')
    return beam_text


def get_beam_tokens(generated_dicts: Any, processor: AutoProcessor) -> List[List[str]]:
    '''
    Convert each sequence's generated token ids into token strings.

    The prompt is dropped. One token is returned per entry of
    ``generated_dicts.scores``, so position t lines up with position t of the
    vectors returned by calculate_score_vectors.

    NOTE(review): if a sequence stops before the others, generate presumably
    pads it to the common length. Those padding ids are included here, as they
    are in calculate_score_vectors. Confirm whether callers should trim them.
    '''
    # generate stores one scores entry per step, so this is the number of
    # generated tokens.
    gen_len = len(generated_dicts.scores)
    beam_tokens = []
    for sequence in generated_dicts.sequences:
        # Use an explicit offset rather than sequence[-gen_len:], which would
        # return the whole sequence when gen_len == 0.
        input_len = sequence.shape[0] - gen_len
        generated_ids = sequence[input_len:].tolist()
        # Subword strings as stored in the vocabulary (e.g. ' Hello', ' world').
        beam_tokens.append(processor.tokenizer.convert_ids_to_tokens(generated_ids))
    return beam_tokens


def calculate_score_vectors(model: AutoModelForCausalLM, generated_dicts: Any) -> List[List[float]]:
    '''
    Return, for each generated sequence, the probability of every token it chose.

    Reads the per-step ``generated_dicts.scores`` captured during generation
    (``output_scores=True``) instead of running a second full forward pass.
    This avoids GPU timeouts on ZeroGPU. Per the transformers docs these scores
    are the *processed* logits (after temperature / top-k / top-p). The values
    are therefore probabilities under the sampling distribution, not the raw
    model distribution.

    Args:
        model: Unused; kept for interface compatibility.
        generated_dicts: Output of ``model.generate(...,
            return_dict_in_generate=True, output_scores=True)``, as produced
            by make_beams. Assumes row i of each scores step belongs to
            ``sequences[i]``. This holds for sampling; with beam search the
            rows would not line up (TODO confirm if reused there).

    Returns:
        One list of floats per sequence, each of length
        ``len(generated_dicts.scores)``.
    '''
    sequences = generated_dicts.sequences
    num_sequences = sequences.shape[0]
    gen_len = len(generated_dicts.scores)
    if gen_len == 0:
        # Nothing was generated; torch.stack would reject an empty tuple.
        return [[] for _ in range(num_sequences)]
    # Prompt length: sequences hold prompt + generated tokens.
    input_len = sequences.shape[1] - gen_len

    # (gen_len, num_sequences, vocab_size): one logits row per step and sequence.
    all_logits = torch.stack(generated_dicts.scores, dim=0)
    all_probs = F.softmax(all_logits, dim=-1)

    # Chosen token ids arranged as (gen_len, num_sequences) to match all_probs.
    chosen_ids = sequences[:, input_len:].transpose(0, 1).to(all_probs.device)
    # One gather picks every chosen token's probability. This replaces
    # gen_len * num_sequences separate .item() calls, each a device sync on GPU.
    chosen_probs = all_probs.gather(-1, chosen_ids.unsqueeze(-1)).squeeze(-1)

    # Back to (num_sequences, gen_len), copied to the host once.
    return chosen_probs.transpose(0, 1).tolist()