# Cache-Augmented Generation demo: pre-compute a prompt's KV cache once,
# then answer questions by token-by-token greedy decoding against that cache.
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from transformers.cache_utils import DynamicCache | |
| import streamlit as st | |
# Allow-list types for torch's safe (weights_only) deserialization —
# presumably a saved KV cache is torch.load-ed elsewhere in the app; confirm.
# One call with both types replaces two consecutive add_safe_globals calls.
torch.serialization.add_safe_globals([DynamicCache, set])
| # Minimal generate function for token-by-token generation | |
def generate(model, input_ids, past_key_values, max_new_tokens=50):
    """Greedily generate up to ``max_new_tokens`` tokens, one at a time.

    Args:
        model: Causal LM whose forward accepts ``past_key_values``/``use_cache``.
        input_ids: (batch=1, seq) prompt token ids; moved to the model's device.
        past_key_values: Pre-populated KV cache (e.g. from ``get_kv_cache``);
            extended as generation proceeds.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Tensor of shape (1, n_new) holding only the newly generated ids.
    """
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids

    # BUGFIX: eos_token_id may be a single int or a list of ints depending on
    # the model config; the old scalar comparison never matched list configs,
    # so generation always ran to max_new_tokens. Normalize to a set.
    eos_ids = model.config.eos_token_id
    if eos_ids is None:
        eos_ids = set()
    elif isinstance(eos_ids, (list, tuple)):
        eos_ids = set(eos_ids)
    else:
        eos_ids = {eos_ids}

    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True,
            )
            # Greedy decoding: highest-logit token at the last position.
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)
            if token.item() in eos_ids:
                break
    return output_ids[:, origin_len:]  # Return just the newly generated part
def get_kv_cache(model, tokenizer, prompt):
    """Pre-compute the key/value cache for ``prompt``.

    Encodes the prompt, runs one forward pass so the model populates the
    cache, and hands the cache back so later calls (see ``generate``) can
    reuse it instead of re-processing the prompt.

    Returns:
        ``(cache, prompt_len)`` — the populated ``DynamicCache`` and the
        number of prompt tokens it covers.
    """
    target_device = model.model.embed_tokens.weight.device
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(target_device)
    kv_cache = DynamicCache()  # it grows as text is generated
    # Single no-grad forward pass purely to fill the cache.
    with torch.no_grad():
        model(input_ids=prompt_ids, past_key_values=kv_cache, use_cache=True)
    return kv_cache, prompt_ids.shape[-1]
# Load the model and tokenizer (the caller is responsible for caching them)
def load_model_and_tokenizer():
    """Download and return the ``(model, tokenizer)`` pair used by the app.

    Uses fp16 when a CUDA device is available (fp32 otherwise) and lets
    transformers place the weights automatically via ``device_map="auto"``.
    """
    repo_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    return model, tokenizer