# support-system/src/generation.py
# Author: ayush2917
# Last update: src/generation.py @ commit 6b02199 (verified)
import logging
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
import time
logger = logging.getLogger(__name__)
class ResponseGenerator:
    """Generate chatbot replies with a causal LM loaded from a local cache."""

    def __init__(self, model_name="distilgpt2", cache_folder=None):
        """
        Initialize the ResponseGenerator with a transformer model and tokenizer.

        Args:
            model_name (str): Name of the transformer model (default: 'distilgpt2').
            cache_folder (str, optional): Directory to cache model files (default: None).

        Raises:
            Exception: Re-raised after logging if the tokenizer or model
                cannot be loaded from the local cache.
        """
        logger.info(f"Initializing ResponseGenerator with model: {model_name}, cache_folder: {cache_folder}")
        start_time = time.time()
        try:
            # Log cache contents to help diagnose missing-file failures.
            if cache_folder and os.path.exists(cache_folder):
                logger.info(f"Cache folder contents: {os.listdir(cache_folder)}")
            # local_files_only=True: never hit the network; fail fast if the
            # cache is incomplete rather than stalling on a download.
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_folder,
                local_files_only=True
            )
            logger.info(f"Tokenizer loaded in {time.time() - start_time:.2f} seconds")
            start_time = time.time()
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_folder,
                local_files_only=True
            )
            logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Failed to load transformer model: {str(e)}")
            raise
        # GPT-2-family tokenizers ship without a pad token; generate() needs
        # one (via pad_token_id) to terminate/pad sequences cleanly.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        logger.info("ResponseGenerator model loaded successfully")

    def generate(self, user_message: str, context: List[Dict]) -> str:
        """
        Generate a response based on the user message and retrieved context.

        Args:
            user_message (str): The user's input message.
            context (List[Dict]): Retrieved documents; each dict is expected
                to carry its text under the 'content' key (missing keys are
                treated as empty text rather than failing the whole call).

        Returns:
            str: Generated response, or a fixed apology string on any error.
        """
        logger.info(f"Generating response for user message: {user_message}")
        try:
            # Combine context and user message into a single prompt.
            context_text = " ".join(doc.get('content', '') for doc in context)
            input_text = f"Context: {context_text}\nUser: {user_message}\nBot:"
            # Truncate the prompt so it fits the model's context window.
            inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
            # BUG FIX: the original passed max_length=100, which counts the
            # prompt tokens too — any prompt longer than 100 tokens produced
            # no new text. max_new_tokens bounds only the continuation.
            # Passing attention_mask and pad_token_id explicitly avoids the
            # transformers warning and ambiguous masking with no pad token.
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=100,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                pad_token_id=self.tokenizer.eos_token_id
            )
            # Decode and keep only the text after the final "Bot:" marker.
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logger.info("Response generated successfully")
            return response.split("Bot:")[-1].strip()
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return "Sorry, I couldn't generate a response."