ayush2917 committed on
Commit 9dbaed3 · verified · 1 Parent(s): 3c3d26a

Update src/generation.py

Files changed (1)
1. src/generation.py +60 -8
src/generation.py CHANGED
@@ -1,10 +1,62 @@
-from transformers import pipeline
+import logging
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from typing import List, Dict
+
+logger = logging.getLogger(__name__)
 
 class ResponseGenerator:
-    def __init__(self, model_name='distilgpt2'):
-        self.generator = pipeline('text-generation', model=model_name)
-
-    def generate(self, prompt: str, context: list, max_length: int = 150):
-        context_str = "\n".join([doc['content'] for doc in context])
-        full_prompt = f"Context: {context_str}\nQuestion: {prompt}\nAnswer:"
-        return self.generator(full_prompt, max_length=max_length, num_return_sequences=1)[0]['generated_text']
+    def __init__(self, model_name="distilgpt2", cache_folder=None):
+        """
+        Initialize the ResponseGenerator with a transformer model and tokenizer.
+
+        Args:
+            model_name (str): Name of the transformer model (default: 'distilgpt2').
+            cache_folder (str, optional): Directory to cache model files (default: None).
+        """
+        logger.info(f"Initializing ResponseGenerator with model: {model_name}, cache_folder: {cache_folder}")
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_folder)
+            self.model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_folder)
+        except Exception as e:
+            logger.error(f"Failed to load transformer model: {str(e)}")
+            raise
+        logger.info("ResponseGenerator model loaded successfully")
+
+    def generate(self, user_message: str, context: List[Dict]) -> str:
+        """
+        Generate a response based on the user message and retrieved context.
+
+        Args:
+            user_message (str): The user's input message.
+            context (List[Dict]): Retrieved documents for context.
+
+        Returns:
+            str: Generated response.
+        """
+        logger.info(f"Generating response for user message: {user_message}")
+        try:
+            # Combine context and user message
+            context_text = " ".join([doc['content'] for doc in context])
+            input_text = f"Context: {context_text}\nUser: {user_message}\nBot:"
+
+            # Tokenize input
+            inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
+
+            # Generate response
+            outputs = self.model.generate(
+                inputs["input_ids"],
+                max_length=100,
+                num_return_sequences=1,
+                no_repeat_ngram_size=2,
+                do_sample=True,
+                top_k=50,
+                top_p=0.95
+            )
+
+            # Decode response
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            logger.info("Response generated successfully")
+            return response.split("Bot:")[-1].strip()
+        except Exception as e:
+            logger.error(f"Error generating response: {str(e)}")
+            return "Sorry, I couldn't generate a response."