Spaces:
Running on Zero
Running on Zero
| from transformers import AutoProcessor, AutoModelForCausalLM | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import exp | |
| from typing import Any, Tuple, List | |
def start_model(model_id: str = "google/gemma-4-31B-it"):
    '''
    Loads the processor and the causal language model for ``model_id``.

    Args:
        model_id: Identifier handed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, processor)`` tuple -- note the model comes first.
    '''
    # Loading options forwarded to the model only; "auto" defers the choice
    # of dtype and device placement to the library.
    load_options = {"dtype": "auto", "device_map": "auto"}

    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_options)

    print(f'Model {model_id} has been installed.')
    return model, processor
def make_beams(
    model: AutoModelForCausalLM,
    processor: AutoProcessor,
    initial_prompt: str,
    temperature: float = 1.0,
    *,
    num_return_sequences: int = 3,
    max_new_tokens: int = 1024,
) -> Tuple[Any, List[str]]:
    '''
    Samples several independent responses to a prompt.

    The prompt is wrapped in a chat with a fixed system message, rendered
    through the processor's chat template (with ``enable_thinking=False``)
    and passed to ``model.generate`` in pure sampling mode (``num_beams=1``,
    ``top_p=0.9``, ``top_k=50``).

    Args:
        model: Model whose ``generate`` method is called.
        processor: Processor used for the chat template, tokenization and
            decoding.
        initial_prompt: Text used as the user turn.
        temperature: Sampling temperature. Values that are not strictly
            positive are replaced by 0.1.
        num_return_sequences: How many samples to draw (default 3).
        max_new_tokens: Upper bound on generated tokens per sample
            (default 1024).

    Returns:
        ``(generated_dicts, transcription)``: the raw ``generate`` output
        (requested with ``return_dict_in_generate=True`` and
        ``output_scores=True``) and the decoded text of every returned
        sequence with special tokens skipped.
        NOTE(review): the full returned sequences are decoded, which for a
        decoder-only model presumably includes the prompt -- confirm.
    '''
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": initial_prompt},
    ]

    # Process input
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = processor(text=text, return_tensors="pt").to(model.device)

    # Generate output
    generated_dicts = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=1,  # Disable beam search for pure sampling
        num_return_sequences=num_return_sequences,  # Independent samples
        return_dict_in_generate=True,
        output_scores=True,
        temperature=temperature if temperature > 0 else 0.1,  # Ensure T > 0 for sampling
        top_p=0.9,  # Nucleus sampling
        top_k=50,  # Top-K sampling to filter noise
        do_sample=True,
    )
    transcription = processor.batch_decode(
        generated_dicts.sequences, skip_special_tokens=True
    )

    print('Keys in model output -------------------')
    for key in generated_dicts:
        print(key)
    print('----------------------------------------')

    print('Beam scores ----------------------------')
    # sequences_scores is only produced by beam search. In sampling mode no
    # overall score is computed here; per-token probabilities have to be
    # derived from generated_dicts.scores by the caller.
    # A present-but-None attribute is treated the same as a missing one.
    sequences_scores = getattr(generated_dicts, 'sequences_scores', None)
    if sequences_scores is not None:
        for score in sequences_scores:
            print(exp(score).item())
    else:
        print('Sampling mode: sequences_scores not available.')
    print('----------------------------------------')

    return generated_dicts, transcription
def parse_beams(transcription: List[str]) -> List[str]:
    '''
    Strips the leading part of each decoded beam, keeping the text that
    follows the first "model\\nthought" marker.

    Args:
        transcription: Decoded beam texts.

    Returns:
        One string per input beam, in the same order. A beam containing the
        marker yields everything after its first occurrence, with leading
        and trailing newlines removed; a beam without the marker is returned
        unchanged.
    '''
    marker = 'model\nthought'
    beam_text = []
    for beam in transcription:
        # partition() cuts at the first marker only, so a reply that happens
        # to repeat the marker is kept whole instead of being truncated at
        # the second occurrence.
        _, found, tail = beam.partition(marker)
        beam_text.append(tail.strip('\n') if found else beam)
    print('Beams have been parsed. --------------')
    return beam_text
def get_beam_tokens(generated_dicts: Any, processor: AutoProcessor) -> List[List[str]]:
    '''
    Turns the generated tail of every sequence into a list of token strings.

    Args:
        generated_dicts: Generation output exposing ``scores`` and
            ``sequences``.
        processor: Processor whose ``tokenizer`` maps ids to token strings.

    Returns:
        One list of tokens per sequence, covering only the last
        ``len(generated_dicts.scores)`` positions of that sequence.
    '''
    # The length of the scores list is taken as the number of generated
    # tokens; everything before that tail is treated as input.
    n_steps = len(generated_dicts.scores)
    return [
        processor.tokenizer.convert_ids_to_tokens(
            seq[seq.shape[0] - n_steps:].tolist()
        )
        for seq in generated_dicts.sequences
    ]
def calculate_score_vectors(model: AutoModelForCausalLM, generated_dicts: Any) -> List[List[float]]:
    '''
    Builds, for every generated sequence, the list of probabilities assigned
    to the token that was actually chosen at each generation step.

    The probabilities come from ``generated_dicts.scores`` rather than from a
    fresh forward pass, to prevent GPU timeouts on ZeroGPU.

    Args:
        model: Unused; kept so existing callers keep working.
        generated_dicts: Generation output exposing ``sequences`` (2-D ids,
            input followed by generated tokens) and ``scores`` (one
            ``(num_sequences, vocab_size)`` tensor per generated step).

    Returns:
        ``num_sequences`` lists of ``len(scores)`` floats each. When no step
        was generated every list is empty.

    NOTE(review): ``scores`` are presumably the processed scores (after
    temperature / top-k / top-p), so these are probabilities under the
    sampling distribution, not the raw model distribution -- confirm.
    NOTE(review): positions after a sequence has finished (padding) are
    scored like any other position -- confirm whether callers expect that.
    '''
    step_scores = generated_dicts.scores
    sequences = generated_dicts.sequences
    num_sequences = sequences.shape[0]
    gen_len = len(step_scores)

    # torch.stack rejects an empty list, so handle "nothing generated" here.
    if gen_len == 0:
        return [[] for _ in range(num_sequences)]

    # Generated ids are the trailing gen_len columns of each sequence.
    input_len = sequences.shape[1] - gen_len
    chosen_ids = sequences[:, input_len:]  # (num_sequences, gen_len)

    # Work one step at a time: this avoids materialising a full
    # (gen_len, num_sequences, vocab_size) probability tensor, and gather()
    # replaces one .item() call (a device sync on GPU) per token.
    chosen_probs = []
    for t, scores_t in enumerate(step_scores):
        # .float() is a no-op for float32 and keeps softmax accurate for
        # half-precision scores.
        probs_t = F.softmax(scores_t.float(), dim=-1)
        ids_t = chosen_ids[:, t].to(probs_t.device).unsqueeze(-1)
        chosen_probs.append(probs_t.gather(-1, ids_t).squeeze(-1))

    # (num_sequences, gen_len) -> nested Python lists of floats.
    return torch.stack(chosen_probs, dim=1).tolist()