import logging
import os
import time
from typing import List, Dict

from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)


class ResponseGenerator:
    def __init__(self, model_name="distilgpt2", cache_folder=None):
        """
        Initialize the ResponseGenerator with a transformer model and tokenizer.

        Args:
            model_name (str): Name of the transformer model (default: 'distilgpt2').
            cache_folder (str, optional): Directory to cache model files (default: None).
        """
        logger.info(f"Initializing ResponseGenerator with model: {model_name}, cache_folder: {cache_folder}")
        start_time = time.time()
        try:
            # Log cache contents for debugging
            if cache_folder and os.path.exists(cache_folder):
                logger.info(f"Cache folder contents: {os.listdir(cache_folder)}")

            # Load tokenizer and model from the local cache only (no network access)
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_folder,
                local_files_only=True
            )
            logger.info(f"Tokenizer loaded in {time.time() - start_time:.2f} seconds")

            start_time = time.time()
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_folder,
                local_files_only=True
            )
            logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")

            # GPT-2-style models ship without a pad token; reuse the EOS token
            # so padding during generation does not trigger warnings or errors.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        except Exception as e:
            logger.error(f"Failed to load transformer model: {str(e)}")
            raise
        logger.info("ResponseGenerator model loaded successfully")

    def generate(self, user_message: str, context: List[Dict]) -> str:
        """
        Generate a response based on the user message and retrieved context.

        Args:
            user_message (str): The user's input message.
            context (List[Dict]): Retrieved documents for context.

        Returns:
            str: Generated response.
        """
        logger.info(f"Generating response for user message: {user_message}")
        try:
            # Combine retrieved context and user message into a single prompt
            context_text = " ".join([doc['content'] for doc in context])
            input_text = f"Context: {context_text}\nUser: {user_message}\nBot:"

            # Tokenize input, truncating to fit the model's context window
            inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

            # Generate response. max_new_tokens bounds only the generated text,
            # so a long prompt cannot exhaust the length budget (the original
            # max_length=100 would fail for prompts longer than 100 tokens).
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=100,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                pad_token_id=self.tokenizer.eos_token_id
            )

            # Decode and keep only the text after the final "Bot:" marker
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logger.info("Response generated successfully")
            return response.split("Bot:")[-1].strip()
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return "Sorry, I couldn't generate a response."
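

# Minimal usage sketch (an illustrative assumption, not part of the original module):
# because the loader uses local_files_only=True, "distilgpt2" must already be present
# in the cache_folder, and the retriever output is assumed to be a list of dicts
# carrying a 'content' key, matching what generate() expects.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Hypothetical cache directory; adjust to wherever the model files live.
    generator = ResponseGenerator(model_name="distilgpt2", cache_folder="./model_cache")

    # Hypothetical retrieved documents standing in for a real retriever's output.
    retrieved_docs = [
        {"content": "Our store is open Monday to Friday, 9am to 5pm."},
        {"content": "Returns are accepted within 30 days of purchase."},
    ]

    reply = generator.generate("What are your opening hours?", retrieved_docs)
    print(reply)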