File size: 3,242 Bytes
9dbaed3
6b02199
9dbaed3
 
6b02199
9dbaed3
 
2b5df38
486c63e
9dbaed3
 
 
 
 
 
 
 
 
6b02199
9dbaed3
f1c5b50
 
 
 
 
 
 
 
 
6b02199
 
f1c5b50
 
 
 
 
6b02199
9dbaed3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import logging
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
import time

logger = logging.getLogger(__name__)

class ResponseGenerator:
    """Generate chat replies with a causal language model, conditioned on retrieved context.

    The tokenizer and model are loaded once, at construction, from the local
    Hugging Face cache only (``local_files_only=True``); no download is attempted.
    """

    def __init__(self, model_name="distilgpt2", cache_folder=None):
        """
        Initialize the ResponseGenerator with a transformer model and tokenizer.

        Args:
            model_name (str): Name of the transformer model (default: 'distilgpt2').
            cache_folder (str, optional): Directory to cache model files (default: None).

        Raises:
            Exception: Whatever ``from_pretrained`` raises (e.g. when the model is
                not present in the local cache) is logged with its traceback and
                re-raised unchanged.
        """
        logger.info(
            "Initializing ResponseGenerator with model: %s, cache_folder: %s",
            model_name,
            cache_folder,
        )
        try:
            # Log cache contents for debugging. isdir (rather than exists) avoids
            # os.listdir raising on a regular file, which would otherwise be
            # misreported below as a model-load failure.
            if cache_folder and os.path.isdir(cache_folder):
                logger.info("Cache folder contents: %s", os.listdir(cache_folder))

            step_start = time.perf_counter()
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_folder,
                local_files_only=True,
            )
            # Prompts end with the user's turn and the "Bot:" cue; when a long
            # context forces truncation, drop tokens from the left so that tail
            # is kept (the default right-side truncation would cut it off).
            self.tokenizer.truncation_side = "left"
            logger.info(
                "Tokenizer loaded in %.2f seconds", time.perf_counter() - step_start
            )

            step_start = time.perf_counter()
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_folder,
                local_files_only=True,
            )
            logger.info(
                "Model loaded in %.2f seconds", time.perf_counter() - step_start
            )
        except Exception as e:
            # Keep the traceback in the log, then let the caller decide.
            logger.exception("Failed to load transformer model: %s", e)
            raise
        logger.info("ResponseGenerator model loaded successfully")

    def generate(
        self,
        user_message: str,
        context: List[Dict],
        *,
        max_new_tokens: int = 100,
        max_input_tokens: int = 512,
    ) -> str:
        """
        Generate a response based on the user message and retrieved context.

        Args:
            user_message (str): The user's input message.
            context (List[Dict]): Retrieved documents; each document's ``'content'``
                value is joined into the prompt. Documents without a truthy
                ``'content'`` are skipped; ``None`` is treated as no context.
            max_new_tokens (int): Upper bound on the number of generated tokens,
                independent of prompt length (default: 100).
            max_input_tokens (int): Prompt token budget; longer prompts are
                truncated by the tokenizer (default: 512).

        Returns:
            str: The generated reply, i.e. only the newly generated text, cut at the
            first hallucinated ``"\\nUser:"`` turn. On any error, a fixed apology
            string is returned instead of raising.
        """
        logger.info("Generating response for user message: %s", user_message)
        try:
            # Combine context and user message
            context_text = " ".join(
                doc["content"] for doc in (context or []) if doc.get("content")
            )
            input_text = f"Context: {context_text}\nUser: {user_message}\nBot:"

            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                max_length=max_input_tokens,
                truncation=True,
            )
            input_ids = inputs["input_ids"]
            prompt_len = len(input_ids[0])

            # GPT-2 style tokenizers define no pad token; fall back to EOS
            # explicitly instead of letting generate() pick one with a warning.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            # max_new_tokens (not max_length) so a long prompt cannot consume
            # the whole generation budget.
            outputs = self.model.generate(
                input_ids,
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                pad_token_id=pad_token_id,
            )

            # Decode only the continuation, so "Bot:" markers in the context or
            # prompt cannot leak into the reply; stop at a made-up next user turn.
            reply = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )
            reply = reply.split("\nUser:", 1)[0].strip()
            logger.info("Response generated successfully")
            return reply
        except Exception as e:
            # Best-effort boundary: log with traceback, return a user-facing fallback.
            logger.exception("Error generating response: %s", e)
            return "Sorry, I couldn't generate a response."