# NOTE(review): the lines previously here were non-code page residue from a
# Hugging Face Spaces file viewer (status text, file size, commit hashes,
# column ruler). They were not Python and broke parsing; replaced with this note.
import logging
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
import time
logger = logging.getLogger(__name__)
class ResponseGenerator:
    """Generate chat responses with a causal LM loaded from a local cache.

    Wraps a Hugging Face tokenizer/model pair (default: distilgpt2) and
    produces a reply conditioned on retrieved context documents.
    """

    def __init__(self, model_name="distilgpt2", cache_folder=None):
        """
        Initialize the ResponseGenerator with a transformer model and tokenizer.

        Args:
            model_name (str): Name of the transformer model (default: 'distilgpt2').
            cache_folder (str, optional): Directory to cache model files (default: None).

        Raises:
            Exception: Re-raised from ``from_pretrained`` when the model or
                tokenizer cannot be loaded — note ``local_files_only=True``
                forbids downloading, so the files must already be cached.
        """
        logger.info(f"Initializing ResponseGenerator with model: {model_name}, cache_folder: {cache_folder}")
        start_time = time.time()
        try:
            # Log cache contents to help diagnose "file not found" failures
            # caused by local_files_only=True.
            if cache_folder and os.path.exists(cache_folder):
                logger.info(f"Cache folder contents: {os.listdir(cache_folder)}")
            # Load tokenizer and model strictly from the local cache (no network).
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_folder,
                local_files_only=True
            )
            logger.info(f"Tokenizer loaded in {time.time() - start_time:.2f} seconds")
            start_time = time.time()
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_folder,
                local_files_only=True
            )
            logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
            # FIX: GPT-2-family tokenizers ship without a pad token, which
            # makes generate() warn and leaves padding undefined. Reuse EOS,
            # the standard workaround for decoder-only models.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        except Exception as e:
            logger.error(f"Failed to load transformer model: {str(e)}")
            raise
        logger.info("ResponseGenerator model loaded successfully")

    def generate(self, user_message: str, context: List[Dict]) -> str:
        """
        Generate a response based on the user message and retrieved context.

        Args:
            user_message (str): The user's input message.
            context (List[Dict]): Retrieved documents for context; each dict
                must have a 'content' key (assumed from usage — TODO confirm
                against the retriever that produces these).

        Returns:
            str: Generated response, or a fixed apology string if generation
            fails (errors are logged, never propagated to the caller).
        """
        logger.info(f"Generating response for user message: {user_message}")
        try:
            # Combine retrieved documents and the user message into one prompt.
            context_text = " ".join([doc['content'] for doc in context])
            input_text = f"Context: {context_text}\nUser: {user_message}\nBot:"
            # Tokenize, truncating the prompt to the model's usable window.
            inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
            # FIX: the original used max_length=100, but max_length counts the
            # prompt tokens too — with prompts up to 512 tokens, generation
            # would error or produce no new text. max_new_tokens bounds only
            # the continuation. Also pass attention_mask and pad_token_id
            # explicitly to silence warnings and make padding well-defined.
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=100,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # Decode the full sequence, then keep only the text after the
            # final "Bot:" marker (the prompt echoes one "Bot:" itself).
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logger.info("Response generated successfully")
            return response.split("Bot:")[-1].strip()
        except Exception as e:
            # Deliberate best-effort: log and return a canned apology rather
            # than crash the caller's chat loop.
            logger.error(f"Error generating response: {str(e)}")
            return "Sorry, I couldn't generate a response."