midrees2806 committed on
Commit
729f39a
·
verified ·
1 Parent(s): 1d3116d

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +188 -51
rag.py CHANGED
import json
import glob
import os
import pandas as pd
from datetime import datetime
from dotenv import load_dotenv

# AI and Data Libraries
from sentence_transformers import SentenceTransformer, util
from groq import Groq
from datasets import load_dataset, Dataset

# Image and Utility Libraries
import requests
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
import numpy as np

# Load environment variables
load_dotenv()

# Initialize Groq client
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Load models and dataset
similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Config
HF_DATASET_REPO = "midrees2806/unmatched_queries"
HF_TOKEN = os.getenv("HF_TOKEN")

# Load multiple JSON datasets from the 'datasets' folder
dataset = []
try:
    # Using glob to find all json files in the folder
    for file_path in glob.glob('datasets/*.json'):
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
            if not isinstance(data, list):
                print(f"Skipping {file_path}: File does not contain a list.")
                continue
            # Keep only well-formed Q/A records.
            dataset.extend(
                item for item in data
                if isinstance(item, dict) and 'Question' in item and 'Answer' in item
            )
except Exception as e:
    print(f"Error loading datasets: {e}")

# Precompute embeddings
dataset_questions = [entry.get("Question", "").lower().strip() for entry in dataset]
dataset_answers = [entry.get("Answer", "") for entry in dataset]

# None marks "no data loaded" so the query path can bail out gracefully.
dataset_embeddings = None
if dataset_questions:
    dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
else:
    print("Warning: No data found in the datasets folder.")
 
# Save unmatched queries to Hugging Face
def manage_unmatched_queries(query: str):
    """Persist an unanswered user query to the HF dataset repo.

    Best-effort: any failure is printed and swallowed so the chat flow is
    never interrupted by persistence problems.
    """
    try:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        try:
            # Pull the existing log; fall back to an empty frame if the
            # repo does not exist yet or cannot be loaded.
            ds = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
            df = ds["train"].to_pandas()
        except Exception:
            # FIX: was a bare `except:`, which also traps SystemExit and
            # KeyboardInterrupt — narrow to Exception.
            df = pd.DataFrame(columns=["Query", "Timestamp", "Processed"])
        # De-duplicate: only append queries not seen before.
        if query not in df["Query"].values:
            new_entry = {"Query": query, "Timestamp": timestamp, "Processed": False}
            df = pd.concat([df, pd.DataFrame([new_entry])], ignore_index=True)
            updated_ds = Dataset.from_pandas(df)
            updated_ds.push_to_hub(HF_DATASET_REPO, token=HF_TOKEN)
    except Exception as e:
        print(f"Failed to save query: {e}")
# Query Groq LLM
def query_groq_llm(prompt, model_name="llama3-70b-8192"):
    """Send *prompt* to the Groq chat API and return the stripped reply.

    Returns an empty string when the request fails for any reason.
    """
    try:
        messages = [{"role": "user", "content": prompt}]
        completion = groq_client.chat.completions.create(
            messages=messages,
            model=model_name,
            temperature=0.7,
            max_tokens=500,
        )
        reply = completion.choices[0].message.content
        return reply.strip()
    except Exception as e:
        print(f"Error querying Groq API: {e}")
        return ""
# Main logic function
def get_best_answer(user_input):
    """Answer a user question for the University of Education chatbot.

    Flow: validate input -> short-circuit fee questions -> similarity
    search over the precomputed dataset embeddings -> refine the matched
    answer via the Groq LLM. Low-confidence queries are logged to the
    Hugging Face dataset for later inclusion.
    """
    if not user_input.strip():
        return "Please enter a valid question."

    user_input_lower = user_input.lower().strip()

    # 👉 Check if question is about fee.
    # FIX: the old check used raw substring matching, so "fee" also fired
    # on unrelated words (e.g. "coffee"). Match whole, punctuation-stripped
    # words instead; "semester fee" is covered by the "fee" token.
    words = {w.strip(".,!?;:") for w in user_input_lower.split()}
    if words & {"fee", "fees", "charges"}:
        return (
            "💰 For complete and up-to-date fee details for this program, we recommend visiting the official University of Education fee structure page.\n"
            "You’ll find comprehensive information regarding tuition, admission charges, and other applicable fees there.\n"
            "🔗 https://ue.edu.pk/allfeestructure.php"
        )

    # 🔁 Continue with normal similarity-based logic
    if dataset_embeddings is None:
        return "I am currently updating my database. Please try again in a moment."

    user_embedding = similarity_model.encode(user_input_lower, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(user_embedding, dataset_embeddings)[0]
    best_match_idx = similarities.argmax().item()
    best_score = similarities[best_match_idx].item()

    # Log low-confidence queries so they can be added to the dataset later.
    if best_score < 0.65:
        manage_unmatched_queries(user_input)

    if best_score >= 0.65:
        original_answer = dataset_answers[best_match_idx]
        prompt = f"""As an official assistant for University of Education Lahore, provide a clear response:
Question: {user_input}
Original Answer: {original_answer}
Improved Answer:"""
    else:
        prompt = f"""As an official assistant for University of Education Lahore, provide a helpful response:
Include relevant details about university policies.
If unsure, direct to official channels.
Question: {user_input}
Official Answer:"""

    llm_response = query_groq_llm(prompt)

    if llm_response:
        # Strip any echoed section marker from the model output.
        for marker in ["Improved Answer:", "Official Answer:"]:
            if marker in llm_response:
                response = llm_response.split(marker)[-1].strip()
                break
        else:
            response = llm_response
    else:
        # LLM unavailable: fall back to the stored answer or contact info.
        response = dataset_answers[best_match_idx] if best_score >= 0.65 else """For official information:
📞 +92-42-99262231-33
✉️ info@ue.edu.pk
🌐 ue.edu.pk"""

    return response
 
 
import json
import glob
import os
import random
from datetime import datetime

import pandas as pd
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer, util
from groq import Groq
from datasets import load_dataset, Dataset

# Load environment variables
load_dotenv()

# Initialize Groq client
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Load similarity model
similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Config
HF_DATASET_REPO = "midrees2806/unmatched_queries"
HF_TOKEN = os.getenv("HF_TOKEN")

# Greeting list
GREETINGS = [
    "hi", "hello", "hey", "good morning", "good afternoon", "good evening",
    "assalam o alaikum", "salam", "aoa", "hi there",
    "hey there", "greetings"
]

# Fixed rephrased unmatched query responses
UNMATCHED_RESPONSES = [
    "Thank you for your query. We’ve forwarded it to our support team and it will be added soon. In the meantime, you can visit the University of Education official website or reach out via the contact details below.\n\n📞 +92-42-99262231-33\n✉️ info@ue.edu.pk\n🌐 https://ue.edu.pk",
    "We’ve noted your question and it’s in queue for inclusion. For now, please check the University of Education website or contact the administration directly.\n\n📞 +92-42-99262231-33\n✉️ info@ue.edu.pk\n🌐 https://ue.edu.pk",
    "Your query has been recorded. We’ll update the system with relevant information shortly. Meanwhile, you can visit UE's official site or reach out using the details below:\n\n📞 +92-42-99262231-33\n✉️ info@ue.edu.pk\n🌐 https://ue.edu.pk",
    "We appreciate your question. It has been forwarded for further processing. Until it’s available here, feel free to visit the official UE website or use the contact options:\n\n📞 +92-42-99262231-33\n✉️ info@ue.edu.pk\n🌐 https://ue.edu.pk"
]

# Load multiple JSON datasets
dataset = []
try:
    json_files = glob.glob('datasets/*.json')
    for file_path in json_files:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
            if isinstance(data, list):
                for item in data:
                    if isinstance(item, dict) and 'Question' in item and 'Answer' in item:
                        dataset.append(item)
                    else:
                        print(f"Invalid entry in {file_path}: {item}")
            else:
                print(f"File {file_path} does not contain a list.")
except Exception as e:
    print(f"Error loading datasets: {e}")

# Precompute embeddings
dataset_questions = [item.get("Question", "").lower().strip() for item in dataset]
dataset_answers = [item.get("Answer", "") for item in dataset]

# FIX: this commit dropped the empty-dataset guard. Encoding an empty
# question list yields an empty tensor, and every later argmax over the
# similarity scores raises at query time. Restore the None sentinel so the
# query path can detect an empty knowledge base.
if dataset_questions:
    dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
else:
    dataset_embeddings = None
    print("Warning: No data found in the datasets folder.")
# Save unmatched queries to Hugging Face
def manage_unmatched_queries(query: str):
    """Persist an unanswered user query to the HF dataset repo.

    Best-effort: any failure is printed and swallowed so the chat flow is
    never interrupted by persistence problems.
    """
    try:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        try:
            # Pull the existing log; fall back to an empty frame if the
            # repo does not exist yet or cannot be loaded.
            ds = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
            df = ds["train"].to_pandas()
        except Exception:
            # FIX: was a bare `except:`, which also traps SystemExit and
            # KeyboardInterrupt — narrow to Exception.
            df = pd.DataFrame(columns=["Query", "Timestamp", "Processed"])
        # De-duplicate: only append queries not seen before.
        if query not in df["Query"].values:
            new_entry = {"Query": query, "Timestamp": timestamp, "Processed": False}
            df = pd.concat([df, pd.DataFrame([new_entry])], ignore_index=True)
            updated_ds = Dataset.from_pandas(df)
            updated_ds.push_to_hub(HF_DATASET_REPO, token=HF_TOKEN)
    except Exception as e:
        print(f"Failed to save query: {e}")
 
# Query Groq LLM
def query_groq_llm(prompt, model_name="llama3-70b-8192"):
    """Ask the Groq chat endpoint to complete *prompt*.

    Returns the stripped reply text, or "" if the request fails.
    """
    try:
        completion = groq_client.chat.completions.create(
            model=model_name,
            temperature=0.7,
            max_tokens=500,
            messages=[{"role": "user", "content": prompt}],
        )
        return completion.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error querying Groq API: {e}")
        return ""
 
# Main logic function to be called from Gradio
def get_best_answer(user_input):
    """Answer a user question for the University of Education chatbot.

    Flow: validate input -> short-circuit short inputs and fee questions ->
    similarity search over the precomputed dataset embeddings -> have the
    Groq LLM rephrase the matched answer. Low-confidence queries are logged
    to the Hugging Face dataset and receive a canned response.
    """
    if not user_input.strip():
        return "Please enter a valid question."

    user_input_lower = user_input.lower().strip()

    # Reject ultra-short inputs unless they contain a known greeting.
    if len(user_input_lower.split()) < 3 and not any(greet in user_input_lower for greet in GREETINGS):
        return "Please ask your question properly with at least 3 words."

    # Fee questions are always routed to the official fee-structure page.
    if any(keyword in user_input_lower for keyword in ["fee structure", "fees structure", "semester fees", "semester fee"]):
        return (
            "💰 For complete and up-to-date fee details for this program, we recommend visiting the official University of Education fee structure page.\n"
            "You'll find comprehensive information regarding tuition, admission charges, and other applicable fees there.\n"
            "🔗 https://ue.edu.pk/allfeestructure.php"
        )

    # FIX: the knowledge base can be empty (no JSON files found at startup);
    # without this guard the argmax over an empty similarity tensor raises.
    # Treat it like an unmatched query: log it and return a canned reply.
    if not dataset_answers or dataset_embeddings is None:
        manage_unmatched_queries(user_input)
        return random.choice(UNMATCHED_RESPONSES)

    user_embedding = similarity_model.encode(user_input_lower, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(user_embedding, dataset_embeddings)[0]
    best_match_idx = similarities.argmax().item()
    best_score = similarities[best_match_idx].item()

    # Below the confidence threshold: log the query for later inclusion
    # and reply with one of the fixed "we've noted it" responses.
    if best_score < 0.65:
        manage_unmatched_queries(user_input)
        return random.choice(UNMATCHED_RESPONSES)

    original_answer = dataset_answers[best_match_idx]
    prompt = f"""Name is UOE AI Assistant! You are an official assistant for the University of Education Lahore.
Rephrase the following official answer clearly and professionally.
Use structured formatting (like headings, bullet points, or numbered lists) where appropriate.
DO NOT add any new or extra information. ONLY rephrase and improve the clarity and formatting of the original answer.
### Question:
{user_input}
### Original Answer:
{original_answer}
### Rephrased Answer:
"""

    llm_response = query_groq_llm(prompt)

    if llm_response:
        # Strip any echoed section marker from the model output.
        for marker in ["Improved Answer:", "Official Answer:", "Rephrased Answer:"]:
            if marker in llm_response:
                return llm_response.split(marker)[-1].strip()
        return llm_response
    # LLM unavailable: fall back to the raw stored answer.
    return dataset_answers[best_match_idx]