import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

# Allow-list these classes so torch.load(..., weights_only=True) can safely
# deserialize saved DynamicCache objects (which internally contain sets).
torch.serialization.add_safe_globals([DynamicCache])
torch.serialization.add_safe_globals([set])


def generate(model, input_ids, past_key_values, max_new_tokens=50):
    """Greedy token-by-token generation on top of a pre-filled KV cache.

    Purpose: generate new text from ``input_ids`` without re-encoding the
    full context on every step — the prompt's attention states are assumed
    to already live in ``past_key_values``.

    Args:
        model: A causal LM whose first forward pass accepts
            ``past_key_values`` and ``use_cache=True``.
        input_ids: Tensor of shape (1, seq_len) with the *new* tokens to
            condition on (e.g. the user question appended after a cached
            system prompt).
        past_key_values: A cache object (e.g. ``DynamicCache``) holding the
            already-processed context; it is extended in place as decoding
            proceeds.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        Tensor of shape (1, n_generated) containing ONLY the newly
        generated token ids (the input portion is sliced off).
    """
    # NOTE(review): assumes the model exposes .model.embed_tokens — true for
    # Llama/Qwen-style architectures; confirm for other model families.
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids

    # eos_token_id may be None, a single int, or a LIST of ints (common for
    # DeepSeek/Qwen configs).  The original `token.item() == eos_token_id`
    # comparison silently never matched when it was a list, so generation
    # could not stop early.  Normalize to a set of ints once, up front.
    eos = model.config.eos_token_id
    if eos is None:
        eos_ids = set()
    elif isinstance(eos, int):
        eos_ids = {eos}
    else:
        eos_ids = set(eos)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # First iteration feeds all new tokens; subsequent iterations
            # feed only the single most-recent token (the cache covers the rest).
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]
            # Greedy decoding: always pick the highest-probability token.
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)
            if token.item() in eos_ids:
                break

    # Return just the newly generated part.
    return output_ids[:, origin_len:]


def get_kv_cache(model, tokenizer, prompt):
    """Pre-compute the model's key/value attention cache for a prompt.

    Purpose: run the prompt through the model once so its internal
    representations can be reused by :func:`generate` without re-encoding.

    Args:
        model: Causal LM accepting ``past_key_values`` / ``use_cache``.
        tokenizer: Tokenizer paired with ``model``.
        prompt: The text whose states should be cached.

    Returns:
        Tuple ``(cache, prompt_len)`` — the populated ``DynamicCache`` and
        the prompt's token length, kept for later reference.
    """
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()  # grows as tokens are processed / generated

    # Run the model once to populate the KV cache; the logits are discarded.
    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache, input_ids.shape[-1]


@st.cache_resource
def load_model_and_tokenizer():
    """Load and cache the model + tokenizer for the Streamlit session.

    ``st.cache_resource`` ensures the (expensive) download/initialization
    happens once per server process, not on every script rerun.

    Returns:
        Tuple ``(model, tokenizer)`` ready for inference.
    """
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # fp16 only when a GPU is present; fp32 keeps CPU inference stable.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True
    )
    return model, tokenizer