Ahmed-Alghamdi committed on
Commit a7f3645 · verified · 1 Parent(s): 2bc67a7

Update response_generator.py

Files changed (1):
  1. response_generator.py +51 -40
response_generator.py CHANGED
@@ -1,6 +1,6 @@
  # response_generator.py
  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  from utils import setup_logger
  from config import Config

@@ -8,47 +8,65 @@ logger = setup_logger('response_generator')

  class ResponseGenerator:
      def __init__(self):
+         # Use a simpler approach with a summarization/QA pipeline
+         # Since BERT-based models don't generate text, we'll create a simple retrieval-based response
          self.tokenizer = AutoTokenizer.from_pretrained(Config.LLM_MODEL)
-         self.model = AutoModelForCausalLM.from_pretrained(Config.LLM_MODEL)
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.model.to(self.device)
-         logger.info(f"Model loaded and moved to {self.device}")
-
+         logger.info(f"Tokenizer loaded from {Config.LLM_MODEL}")
+
      def generate_response(self, query, relevant_docs):
          try:
-             context = self._prepare_context(relevant_docs)
-             prompt = self._create_prompt(query, context)
-             input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
+             if len(relevant_docs) == 0:
+                 return "عذرًا، لم أجد أي معلومات ذات صلة في المستندات."  # "Sorry, I did not find any relevant information in the documents."

-             attention_mask = input_ids.ne(self.tokenizer.pad_token_id).float()
-
-             with torch.no_grad():
-                 output = self.model.generate(
-                     input_ids,
-                     attention_mask=attention_mask,
-                     max_length=Config.MAX_LENGTH,
-                     num_return_sequences=1,
-                     no_repeat_ngram_size=2,
-                     do_sample=True,
-                     top_k=50,
-                     top_p=0.95,
-                     temperature=0.7
-                 )
+             # Get the most relevant document (first one)
+             context = self._prepare_context(relevant_docs)

-             response = self.tokenizer.decode(output[0], skip_special_tokens=True)
-             return self._extract_answer(response)
-
+             # For BERT-based models, we do extractive QA instead of generation
+             # Return the most relevant context as the answer
+             response = self._create_extractive_answer(query, context, relevant_docs)
+             return response
+
          except Exception as e:
              logger.error(f"Error generating response: {e}")
-             return "عذرًا، لم أتمكن من إنشاء استجابة بسبب خطأ ما."  # "Sorry, I couldn't generate a response due to an error."
-
+             return "عذرًا، لم أتمكن من إنشاء استجابة بسبب خطأ ما."
+
      def _prepare_context(self, relevant_docs):
-         # Combine content from relevant documents
-         combined_content = "\n".join(relevant_docs['content'].tolist())
-         # Truncate if too long
-         max_context_length = Config.MAX_LENGTH // 2  # Use half of max_length for context
+         # Take only the top 3 most relevant documents to avoid the token limit
+         top_docs = relevant_docs.head(3)
+         combined_content = "\n\n".join(top_docs['content'].tolist())
+
+         # Limit to 300 characters to stay within token limits
+         max_context_length = 300
          return combined_content[:max_context_length]
+
+     def _create_extractive_answer(self, query, context, relevant_docs):
+         """
+         Create an answer by extracting relevant information from documents
+         """
+         # Get the most relevant document
+         most_relevant = relevant_docs.iloc[0]['content']
+
+         # Truncate to a reasonable length
+         max_length = 500
+         if len(most_relevant) > max_length:
+             # Try to find a good sentence break
+             truncated = most_relevant[:max_length]
+             last_period = truncated.rfind('.')
+             if last_period > 0:
+                 most_relevant = truncated[:last_period + 1]
+             else:
+                 most_relevant = truncated + "..."
+
+         # Format the response: "Based on the available documents: ... Source: <path>"
+         response = f"""بناءً على المستندات المتاحة:
+
+ {most_relevant}

+ ---
+ المصدر: {relevant_docs.iloc[0]['path']}"""
+
+         return response
+
      def _create_prompt(self, query, context):
          return f"""مستند قانوني:
  {context}
@@ -57,17 +75,10 @@
  {query}

  إجابة:"""
-
-     def _extract_answer(self, response):
-         # Extract the generated answer from the full response
-         answer_start = response.find("إجابة:") + len("إجابة:")
-         return response[answer_start:].strip()
-
+
      def update_model(self, new_model_name):
          try:
              self.tokenizer = AutoTokenizer.from_pretrained(new_model_name)
-             self.model = AutoModelForCausalLM.from_pretrained(new_model_name)
-             self.model.to(self.device)
              logger.info(f"Model updated to {new_model_name}")
          except Exception as e:
              logger.error(f"Error updating model: {e}")
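The accessors in this commit (`head(3)`, `iloc[0]['content']`, `iloc[0]['path']`) imply that `relevant_docs` is a pandas DataFrame with `content` and `path` columns. A minimal usage sketch under that assumption; the rows, paths, and query below are invented placeholders, not data from the repository:

# usage_sketch.py -- hypothetical example, not part of the commit
import pandas as pd

from response_generator import ResponseGenerator

# Invented stand-ins for retrieval output; real rows would come from the
# app's document-retrieval step.
relevant_docs = pd.DataFrame({
    "content": [
        "المادة الأولى: يلتزم صاحب العمل بتوفير بيئة عمل آمنة.",  # "Article 1: The employer must provide a safe work environment."
        "المادة الثانية: تكون فترة التجربة تسعين يومًا.",  # "Article 2: The probation period is ninety days."
    ],
    "path": ["docs/labor_law_1.txt", "docs/labor_law_2.txt"],  # hypothetical paths
})

generator = ResponseGenerator()
answer = generator.generate_response("ما التزامات صاحب العمل؟", relevant_docs)  # "What are the employer's obligations?"
print(answer)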
 
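`_prepare_context` caps the combined context at 300 characters as a stand-in for the model's token limit. Since the class already loads a tokenizer, the cap could be applied at the token level instead. A sketch of that alternative, assuming a `max_tokens` budget that the commit does not specify:

# Hypothetical token-aware variant of _prepare_context; max_tokens is an
# assumed budget, not a value taken from Config.
def prepare_context_by_tokens(tokenizer, relevant_docs, max_tokens=256):
    top_docs = relevant_docs.head(3)
    combined = "\n\n".join(top_docs["content"].tolist())
    # Encode with truncation at the token level, then decode back to text.
    token_ids = tokenizer.encode(combined, truncation=True, max_length=max_tokens)
    return tokenizer.decode(token_ids, skip_special_tokens=True)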
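Note that the updated import list adds `pipeline`, but the commit never calls it. If `Config.LLM_MODEL` names a checkpoint with a trained question-answering head (an assumption the commit does not confirm), the imported `pipeline` could extract an answer span rather than returning the whole top document:

# Hypothetical sketch: span-level extractive QA with the imported pipeline.
# Assumes Config.LLM_MODEL has a trained QA head, which this commit does not confirm.
from transformers import pipeline

from config import Config

qa = pipeline("question-answering", model=Config.LLM_MODEL)

def answer_span(query, context):
    # The pipeline returns a dict with 'answer', 'score', 'start', and 'end'.
    result = qa(question=query, context=context)
    return result["answer"]

A checkpoint without a trained QA head would load with a randomly initialized answer head and produce unreliable spans, which is presumably why the commit falls back to document-level extraction in `_create_extractive_answer`.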