import logging
import os
import time
from typing import Dict, List

from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)
class ResponseGenerator:
    def __init__(self, model_name="distilgpt2", cache_folder=None):
        """
        Initialize the ResponseGenerator with a transformer model and tokenizer.

        Args:
            model_name (str): Name of the transformer model (default: 'distilgpt2').
            cache_folder (str, optional): Directory to cache model files (default: None).
        """
        logger.info(f"Initializing ResponseGenerator with model: {model_name}, cache_folder: {cache_folder}")
        start_time = time.time()
        try:
            # Log cache contents for debugging
            if cache_folder and os.path.exists(cache_folder):
                logger.info(f"Cache folder contents: {os.listdir(cache_folder)}")
            # Load tokenizer and model from the local cache
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_folder,
                local_files_only=True
            )
            logger.info(f"Tokenizer loaded in {time.time() - start_time:.2f} seconds")
            start_time = time.time()
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_folder,
                local_files_only=True
            )
            logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Failed to load transformer model: {e}")
            raise
        logger.info("ResponseGenerator model loaded successfully")
    def generate(self, user_message: str, context: List[Dict]) -> str:
        """
        Generate a response based on the user message and retrieved context.

        Args:
            user_message (str): The user's input message.
            context (List[Dict]): Retrieved documents for context.

        Returns:
            str: Generated response.
        """
        logger.info(f"Generating response for user message: {user_message}")
        try:
            # Combine the retrieved documents and the user message into a single prompt
            context_text = " ".join(doc["content"] for doc in context)
            input_text = f"Context: {context_text}\nUser: {user_message}\nBot:"
            # Tokenize the prompt, truncating so it fits the model's context window
            inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
            # Generate a continuation. max_new_tokens bounds the reply length regardless
            # of how long the (truncated) prompt is, and pad_token_id is set explicitly
            # because GPT-2-style models have no pad token of their own.
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=100,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                pad_token_id=self.tokenizer.eos_token_id
            )
            # Decode the full sequence and keep only the text after the "Bot:" marker
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logger.info("Response generated successfully")
            return response.split("Bot:")[-1].strip()
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return "Sorry, I couldn't generate a response."