Spaces:
Sleeping
Sleeping
Update rag.py
Browse files
rag.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
| 1 |
import json
|
| 2 |
-
|
| 3 |
-
from groq import Groq
|
| 4 |
-
from datetime import datetime
|
| 5 |
import os
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
-
from
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# Load environment variables
|
| 13 |
load_dotenv()
|
|
@@ -25,16 +27,7 @@ HF_TOKEN = os.getenv("HF_TOKEN")
|
|
| 25 |
# Greeting list
|
| 26 |
GREETINGS = [
|
| 27 |
"hi", "hello", "hey", "good morning", "good afternoon", "good evening",
|
| 28 |
-
"assalam o alaikum", "salam", "aoa", "hi there",
|
| 29 |
-
"hey there", "greetings"
|
| 30 |
-
]
|
| 31 |
-
|
| 32 |
-
# Fixed rephrased unmatched query responses
|
| 33 |
-
UNMATCHED_RESPONSES = [
|
| 34 |
-
"Thank you for your query. We’ve forwarded it to our support team and it will be added soon. In the meantime, you can visit the University of Education official website or reach out via the contact details below.\n\n📞 +92-42-99262231-33\n✉️ info@ue.edu.pk\n🌐 https://ue.edu.pk",
|
| 35 |
-
"We’ve noted your question and it’s in queue for inclusion. For now, please check the University of Education website or contact the administration directly.\n\n📞 +92-42-99262231-33\n✉️ info@ue.edu.pk\n🌐 https://ue.edu.pk",
|
| 36 |
-
"Your query has been recorded. We’ll update the system with relevant information shortly. Meanwhile, you can visit UE's official site or reach out using the details below:\n\n📞 +92-42-99262231-33\n✉️ info@ue.edu.pk\n🌐 https://ue.edu.pk",
|
| 37 |
-
"We appreciate your question. It has been forwarded for further processing. Until it’s available here, feel free to visit the official UE website or use the contact options:\n\n📞 +92-42-99262231-33\n✉️ info@ue.edu.pk\n🌐 https://ue.edu.pk"
|
| 38 |
]
|
| 39 |
|
| 40 |
# Load multiple JSON datasets
|
|
@@ -48,10 +41,6 @@ try:
|
|
| 48 |
for item in data:
|
| 49 |
if isinstance(item, dict) and 'Question' in item and 'Answer' in item:
|
| 50 |
dataset.append(item)
|
| 51 |
-
else:
|
| 52 |
-
print(f"Invalid entry in {file_path}: {item}")
|
| 53 |
-
else:
|
| 54 |
-
print(f"File {file_path} does not contain a list.")
|
| 55 |
except Exception as e:
|
| 56 |
print(f"Error loading datasets: {e}")
|
| 57 |
|
|
@@ -60,8 +49,8 @@ dataset_questions = [item.get("Question", "").lower().strip() for item in datase
|
|
| 60 |
dataset_answers = [item.get("Answer", "") for item in dataset]
|
| 61 |
dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
|
| 62 |
|
| 63 |
-
# Save unmatched queries to Hugging Face
|
| 64 |
def manage_unmatched_queries(query: str):
|
|
|
|
| 65 |
try:
|
| 66 |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 67 |
try:
|
|
@@ -69,6 +58,7 @@ def manage_unmatched_queries(query: str):
|
|
| 69 |
df = ds["train"].to_pandas()
|
| 70 |
except:
|
| 71 |
df = pd.DataFrame(columns=["Query", "Timestamp", "Processed"])
|
|
|
|
| 72 |
if query not in df["Query"].values:
|
| 73 |
new_entry = {"Query": query, "Timestamp": timestamp, "Processed": False}
|
| 74 |
df = pd.concat([df, pd.DataFrame([new_entry])], ignore_index=True)
|
|
@@ -77,67 +67,65 @@ def manage_unmatched_queries(query: str):
|
|
| 77 |
except Exception as e:
|
| 78 |
print(f"Failed to save query: {e}")
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
try:
|
| 83 |
chat_completion = groq_client.chat.completions.create(
|
| 84 |
-
messages=[
|
| 85 |
-
"role": "
|
| 86 |
-
"content": prompt
|
| 87 |
-
|
| 88 |
-
model=
|
| 89 |
-
temperature=0.
|
| 90 |
-
max_tokens=
|
| 91 |
)
|
| 92 |
return chat_completion.choices[0].message.content.strip()
|
| 93 |
except Exception as e:
|
| 94 |
print(f"Error querying Groq API: {e}")
|
| 95 |
return ""
|
| 96 |
|
| 97 |
-
# Main logic function to be called from Gradio
|
| 98 |
def get_best_answer(user_input):
|
| 99 |
if not user_input.strip():
|
| 100 |
return "Please enter a valid question."
|
| 101 |
|
| 102 |
user_input_lower = user_input.lower().strip()
|
| 103 |
|
|
|
|
| 104 |
if len(user_input_lower.split()) < 3 and not any(greet in user_input_lower for greet in GREETINGS):
|
| 105 |
-
return "Please ask your question properly with at least 3 words."
|
| 106 |
|
|
|
|
| 107 |
if any(keyword in user_input_lower for keyword in ["fee structure", "fees structure", "semester fees", "semester fee"]):
|
| 108 |
return (
|
| 109 |
-
"💰
|
| 110 |
-
"
|
| 111 |
-
"🔗 https://drive.google.com/file/d/1B30FKoP6GrkS9pQk10PWKCwcjco5E9Cc/view"
|
| 112 |
)
|
| 113 |
|
|
|
|
| 114 |
user_embedding = similarity_model.encode(user_input_lower, convert_to_tensor=True)
|
| 115 |
similarities = util.pytorch_cos_sim(user_embedding, dataset_embeddings)[0]
|
| 116 |
best_match_idx = similarities.argmax().item()
|
| 117 |
best_score = similarities[best_match_idx].item()
|
| 118 |
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
manage_unmatched_queries(user_input)
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
Use structured formatting (like headings, bullet points, or numbered lists) where appropriate.
|
| 127 |
-
DO NOT add any new or extra information. ONLY rephrase and improve the clarity and formatting of the original answer.
|
| 128 |
-
### Question:
|
| 129 |
-
{user_input}
|
| 130 |
-
### Original Answer:
|
| 131 |
-
{original_answer}
|
| 132 |
-
### Rephrased Answer:
|
| 133 |
-
"""
|
| 134 |
-
|
| 135 |
-
llm_response = query_groq_llm(prompt)
|
| 136 |
|
|
|
|
| 137 |
if llm_response:
|
| 138 |
-
for marker in ["Improved Answer:", "Official Answer:", "Rephrased Answer:"]:
|
| 139 |
-
if marker in llm_response:
|
| 140 |
-
return llm_response.split(marker)[-1].strip()
|
| 141 |
return llm_response
|
| 142 |
else:
|
| 143 |
-
return
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
+
import glob
|
|
|
|
|
|
|
| 3 |
import os
|
| 4 |
+
import random
|
| 5 |
import pandas as pd
|
| 6 |
+
from datetime import datetime
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
# Core AI Libraries
|
| 10 |
+
from sentence_transformers import SentenceTransformer, util
|
| 11 |
+
from groq import Groq
|
| 12 |
+
from datasets import load_dataset, Dataset
|
| 13 |
|
| 14 |
# Load environment variables
|
| 15 |
load_dotenv()
|
|
|
|
| 27 |
# Greeting list
|
| 28 |
GREETINGS = [
|
| 29 |
"hi", "hello", "hey", "good morning", "good afternoon", "good evening",
|
| 30 |
+
"assalam o alaikum", "salam", "aoa", "hi there", "hey there", "greetings"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
]
|
| 32 |
|
| 33 |
# Load multiple JSON datasets
|
|
|
|
| 41 |
for item in data:
|
| 42 |
if isinstance(item, dict) and 'Question' in item and 'Answer' in item:
|
| 43 |
dataset.append(item)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
except Exception as e:
|
| 45 |
print(f"Error loading datasets: {e}")
|
| 46 |
|
|
|
|
| 49 |
dataset_answers = [item.get("Answer", "") for item in dataset]
|
| 50 |
dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
|
| 51 |
|
|
|
|
| 52 |
def manage_unmatched_queries(query: str):
|
| 53 |
+
"""Logs unknown queries to Hugging Face Hub."""
|
| 54 |
try:
|
| 55 |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 56 |
try:
|
|
|
|
| 58 |
df = ds["train"].to_pandas()
|
| 59 |
except:
|
| 60 |
df = pd.DataFrame(columns=["Query", "Timestamp", "Processed"])
|
| 61 |
+
|
| 62 |
if query not in df["Query"].values:
|
| 63 |
new_entry = {"Query": query, "Timestamp": timestamp, "Processed": False}
|
| 64 |
df = pd.concat([df, pd.DataFrame([new_entry])], ignore_index=True)
|
|
|
|
| 67 |
except Exception as e:
|
| 68 |
print(f"Failed to save query: {e}")
|
| 69 |
|
| 70 |
+
def query_groq_llm(prompt, system_message):
|
| 71 |
+
"""Utility to call Groq Llama 3 API."""
|
| 72 |
try:
|
| 73 |
chat_completion = groq_client.chat.completions.create(
|
| 74 |
+
messages=[
|
| 75 |
+
{"role": "system", "content": system_message},
|
| 76 |
+
{"role": "user", "content": prompt}
|
| 77 |
+
],
|
| 78 |
+
model="llama3-70b-8192",
|
| 79 |
+
temperature=0.6,
|
| 80 |
+
max_tokens=800
|
| 81 |
)
|
| 82 |
return chat_completion.choices[0].message.content.strip()
|
| 83 |
except Exception as e:
|
| 84 |
print(f"Error querying Groq API: {e}")
|
| 85 |
return ""
|
| 86 |
|
|
|
|
| 87 |
def get_best_answer(user_input):
|
| 88 |
if not user_input.strip():
|
| 89 |
return "Please enter a valid question."
|
| 90 |
|
| 91 |
user_input_lower = user_input.lower().strip()
|
| 92 |
|
| 93 |
+
# 1. Length/Greeting Check
|
| 94 |
if len(user_input_lower.split()) < 3 and not any(greet in user_input_lower for greet in GREETINGS):
|
| 95 |
+
return "Please ask your question properly with at least 3 words so I can assist you better."
|
| 96 |
|
| 97 |
+
# 2. Hardcoded Keyword Check (Fees)
|
| 98 |
if any(keyword in user_input_lower for keyword in ["fee structure", "fees structure", "semester fees", "semester fee"]):
|
| 99 |
return (
|
| 100 |
+
"💰 **Official Fee Information**\n\n"
|
| 101 |
+
"For the most complete and up-to-date fee details for your program at the University of Education Lahore, please visit the official link:\n"
|
| 102 |
+
"🔗 [View Fee Structure](https://drive.google.com/file/d/1B30FKoP6GrkS9pQk10PWKCwcjco5E9Cc/view)"
|
| 103 |
)
|
| 104 |
|
| 105 |
+
# 3. Vector Similarity Search
|
| 106 |
user_embedding = similarity_model.encode(user_input_lower, convert_to_tensor=True)
|
| 107 |
similarities = util.pytorch_cos_sim(user_embedding, dataset_embeddings)[0]
|
| 108 |
best_match_idx = similarities.argmax().item()
|
| 109 |
best_score = similarities[best_match_idx].item()
|
| 110 |
|
| 111 |
+
# 4. DECISION LOGIC
|
| 112 |
+
if best_score >= 0.65:
|
| 113 |
+
# PATH A: High similarity (Rephrase Dataset)
|
| 114 |
+
original_answer = dataset_answers[best_match_idx]
|
| 115 |
+
system_message = "You are the UOE AI Assistant. Your job is to rephrase the provided official answer to make it more attractive, using bold text, bullet points, and clear headings. Do NOT change the facts."
|
| 116 |
+
prompt = f"User Question: {user_input}\nOfficial Context: {original_answer}\n\nRephrase this beautifully:"
|
| 117 |
+
else:
|
| 118 |
+
# PATH B: Low similarity (LLM General Knowledge + Logging)
|
| 119 |
manage_unmatched_queries(user_input)
|
| 120 |
+
system_message = "You are the UOE AI Assistant for the University of Education Lahore. Answer the user's question based on your general knowledge of the university. If you are unsure about specific dates or costs, tell them to contact info@ue.edu.pk."
|
| 121 |
+
prompt = f"The following question about University of Education (UE) was asked: {user_input}\n\nPlease provide a helpful and professional response:"
|
| 122 |
+
|
| 123 |
+
# 5. Execute LLM Call
|
| 124 |
+
llm_response = query_groq_llm(prompt, system_message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
+
# 6. Fallback Strategy
|
| 127 |
if llm_response:
|
|
|
|
|
|
|
|
|
|
| 128 |
return llm_response
|
| 129 |
else:
|
| 130 |
+
# If LLM fails, return raw dataset answer if score was high, else return contact info
|
| 131 |
+
return dataset_answers[best_match_idx] if best_score >= 0.65 else "I'm having trouble connecting to the server. Please contact info@ue.edu.pk."
|