# fifth_try_CAG / model_utils.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import streamlit as st
# Add necessary serialization safety:
# allow-list the classes torch.load is permitted to deserialize under
# weights_only=True, so a DynamicCache saved with torch.save (and the
# plain `set` objects it may contain) can be restored without tripping
# the unpickling restrictions.
torch.serialization.add_safe_globals([DynamicCache])
torch.serialization.add_safe_globals([set])
# Minimal generate function for token-by-token generation
def generate(model, input_ids, past_key_values, max_new_tokens=50):
    """Greedy token-by-token generation reusing a pre-computed KV cache.

    Args:
        model: Causal LM; its embedding weights determine the target device.
        input_ids: (1, seq_len) tensor of prompt token ids.
        past_key_values: Cache (e.g. DynamicCache) already populated for the prompt.
        max_new_tokens: Upper bound on the number of new tokens to generate.

    Returns:
        (1, n_generated) tensor holding only the newly generated token ids
        (the prompt portion is stripped).
    """
    # Use the embedding weights' device so inputs land on the right shard
    # under device_map="auto". NOTE(review): assumes a Llama/Qwen-style
    # architecture exposing `model.model.embed_tokens` — confirm for others.
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids

    # eos_token_id may be None, a single int, or a list of ints depending on
    # the model config. Normalize to a set once (hoisted out of the loop) —
    # the original `token.item() == eos_token_id` comparison silently never
    # matched when the config stores a list, so generation never stopped early.
    eos_ids = model.config.eos_token_id
    if eos_ids is None:
        eos_ids = set()
    elif isinstance(eos_ids, int):
        eos_ids = {eos_ids}
    else:
        eos_ids = set(eos_ids)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True,
            )
            # Greedy decoding: take the argmax at the last position only;
            # the cache makes earlier positions unnecessary to recompute.
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            # Feed only the new token back in; move it in case the lm_head
            # lives on a different shard than the embeddings.
            next_token = token.to(device)
            if token.item() in eos_ids:
                break
    return output_ids[:, origin_len:]  # return just the newly generated part
def get_kv_cache(model, tokenizer, prompt):
    """Pre-fill a key-value cache with the model's states for *prompt*.

    Encodes the prompt, runs one forward pass so the cache captures the
    prompt's key/value states, and returns the cache along with the prompt
    length in tokens (so callers can later strip the prompt from outputs).
    """
    # Put the encoded prompt on the same device as the embedding weights.
    embed_device = model.model.embed_tokens.weight.device
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(embed_device)

    # A DynamicCache grows as text is generated.
    kv_cache = DynamicCache()

    # Single forward pass whose only purpose is to populate kv_cache.
    with torch.no_grad():
        model(input_ids=prompt_ids, past_key_values=kv_cache, use_cache=True)

    return kv_cache, prompt_ids.shape[-1]
# Initialize session state for the model, tokenizer and cache
@st.cache_resource
def load_model_and_tokenizer(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    """Load and cache the causal LM and its tokenizer.

    Cached by Streamlit per argument set, so the (large) model download and
    load happen once per session for a given model id.

    Args:
        model_name: Hugging Face model id to load. Defaults to the original
            hard-coded DeepSeek-R1 1.5B distill, so existing zero-argument
            callers are unaffected.

    Returns:
        (model, tokenizer) tuple. The model uses float16 when CUDA is
        available (float32 on CPU) and device_map="auto" placement.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # Half precision only makes sense on GPU; CPU stays in float32.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True,
    )
    return model, tokenizer