# NOTE: the "Spaces: Sleeping" header lines were a Hugging Face Spaces page
# artifact from extraction, not part of the source module.
| import os | |
| import json | |
| import re | |
# Canned reply used when the assistant cannot answer a question; it redirects
# the student to the admissions office contact address instead of guessing.
UNKNOWN_FALLBACK = "Thank you for your inquiry. Unfortunately, I am unable to provide an answer to your question at this time. For accurate and up-to-date information, please contact the admissions office at admission@nu.edu.eg"
| def _coerce_embedding(vec): | |
| if hasattr(vec, "tolist"): | |
| vec = vec.tolist() | |
| if isinstance(vec, list): | |
| return [float(v) for v in vec] | |
| return None | |
# =========================
# Category Detection
# =========================
def detect_category(question: str) -> str:
    """Detect the category of a question based on keywords.

    English keywords are scored first; Arabic keywords are only consulted
    when no English keyword matched at all.  The category with the highest
    keyword-hit count wins.

    Returns:
        The category name, or ``None`` when no keyword matches
        (callers must handle the ``None`` case).
    """
    query_lower = question.lower()
    # Category keywords based on Categories.txt
    category_keywords = {
        "Admissions": ["apply", "admission", "accept", "requirements", "application", "enroll"],
        "Fees": ["fee", "tuition", "cost", "payment", "credit", "price", "pay", "refund"],
        "Academics": ["gpa", "grades", "scores", "grade", "cgpa", "dean"],
        "Academic Advising": ["advisor", "track", "course", "major", "register", "summer course"],
        "IT & Systems": ["portal", "moodle", "login", "system", "technical", "support"],
        "Emails": ["email", "gmail", "outlook", "mail", "inbox", "address", "contact email"],
    }
    # Lightweight Arabic keyword support for common student queries.
    arabic_category_keywords = {
        "Admissions": ["تقديم", "قبول", "التحاق", "شروط", "متطلبات", "مستندات", "اوراق", "اختبار", "placement", "معادله", "تحويل"],
        "Fees": ["رسوم", "مصاريف", "سعر", "تكلفة", "قسط", "ساعه", "ساعة", "credit", "tuition", "refund", "منحه", "منح"],
        "Academics": ["معدل", "gpa", "cgpa", "درجات", "انسحاب", "drop", "withdraw", "حضور", "غياب", "اختبار", "امتحان"],
        "Academic Advising": ["ادفيزور", "مرشد", "ارشاد", "مقررات", "تسجيل", "ماجور", "تخصص", "خطة", "تراك", "self service"],
        "IT & Systems": ["مودل", "moodle", "بورتال", "بوابه", "بوابة", "سيستم", "تسجيل الدخول", "portal", "حساب", "تقني", "itsupport"],
        "Emails": ["ايميل", "بريد", "outlook", "email", "admission@", "nu.edu.eg"],
    }

    def _best_match(keyword_map):
        # Count substring hits per category; highest count wins.
        scores = {
            category: sum(1 for keyword in keywords if keyword in query_lower)
            for category, keywords in keyword_map.items()
        }
        scores = {category: score for category, score in scores.items() if score > 0}
        if scores:
            return max(scores, key=scores.get)
        return None

    # English keywords take precedence; fall back to the Arabic table only
    # when no English keyword matched (same ordering as the original flow).
    return _best_match(category_keywords) or _best_match(arabic_category_keywords)
# =========================
# Retrieval (Chroma / Local)
# =========================
def retrieve_with_chroma(query_embedding, top_k=5, category_filter=None):
    """Retrieve relevant Q&A pairs from Chroma with optional category filtering."""
    try:
        import chromadb

        persist_dir = os.getenv("CHROMA_PERSIST_DIR", "./chroma_db")

        # Resolve a client: prefer persistent storage, then the legacy
        # Settings API, then a purely in-memory instance.
        try:
            client = chromadb.PersistentClient(path=persist_dir)
        except Exception:
            try:
                from chromadb.config import Settings
                client = chromadb.Client(Settings(persist_directory=persist_dir))
            except Exception:
                try:
                    client = chromadb.EphemeralClient()
                except Exception:
                    return [], []

        # The "qa_knowledge" collection stores the indexed JSON Q&A data.
        try:
            collection = client.get_collection("qa_knowledge")
        except Exception:
            return [], []

        # Build the query arguments once; attach the metadata filter only
        # when a category (e.g. Fees, Admissions) was requested.
        query_kwargs = {"query_embeddings": [query_embedding], "n_results": top_k}
        if category_filter:
            query_kwargs["where"] = {"category": category_filter}
        results = collection.query(**query_kwargs)

        documents = results.get("documents", [[]])[0]
        metadatas = results.get("metadatas", [[]])[0]
        return documents, metadatas
    except Exception as e:
        # If there's an error, print it for debugging and return empty
        print(f"ChromaDB error: {e}")
        return [], []
def initialize_chroma_from_json(embed_fn, collection_name="qa_knowledge"):
    """Ensure Chroma has indexed Q&A entries from data.json.

    Args:
        embed_fn: Callable that takes a list of document strings and returns
            a list of embedding vectors (one per document).
        collection_name: Name of the Chroma collection to (re)build.

    Returns:
        True when the collection was rebuilt successfully; False on any
        failure (missing chromadb, missing data.json, no valid entries,
        or bad embeddings).
    """
    try:
        import chromadb
    except Exception as e:
        print(f"ChromaDB import error: {e}")
        return False
    # data.json is expected next to the current working directory.
    cwd = os.getcwd()
    json_path = os.path.join(cwd, "data.json")
    if not os.path.exists(json_path):
        print("Chroma init skipped: data.json not found")
        return False
    try:
        CHROMA_PERSIST_DIR = os.getenv("CHROMA_PERSIST_DIR", "./chroma_db")
        # Prefer the persistent client; fall back to the legacy Settings API.
        try:
            client = chromadb.PersistentClient(path=CHROMA_PERSIST_DIR)
        except Exception:
            from chromadb.config import Settings
            client = chromadb.Client(Settings(persist_directory=CHROMA_PERSIST_DIR))
        # Always rebuild from current data.json to avoid stale vectors.
        try:
            client.delete_collection(collection_name)
        except Exception:
            pass
        col = client.create_collection(collection_name)
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        ids = []
        docs = []
        metas = []
        for idx, entry in enumerate(data):
            qtext = str(entry.get("question", "")).strip()
            atext = str(entry.get("answer", "")).strip()
            # Skip entries that have neither a question nor an answer.
            if not qtext and not atext:
                continue
            qa_id = str(entry.get("id", idx + 1))
            category = entry.get("category", "General")
            # Suffix with the positional index so duplicate ids stay unique.
            ids.append(f"qa_{qa_id}_{idx}")
            docs.append(f"Question: {qtext}\nAnswer: {atext}")
            metas.append({
                "source": json_path,
                "qa_id": qa_id,
                "category": category,
                "question": qtext,
                "answer": atext,
            })
        if not docs:
            print("Chroma init skipped: no valid Q&A entries")
            return False
        # Embed in batches; each vector is coerced to a plain float list.
        embeddings = []
        batch_size = 32
        for i in range(0, len(docs), batch_size):
            batch_docs = docs[i:i + batch_size]
            batch_emb = embed_fn(batch_docs)
            if not isinstance(batch_emb, list) or len(batch_emb) != len(batch_docs):
                # Fallback to per-item embedding if backend returns unexpected shape.
                batch_emb = [embed_fn([d])[0] for d in batch_docs]
            for emb in batch_emb:
                emb_vec = _coerce_embedding(emb)
                if emb_vec is None:
                    print("Chroma init aborted: invalid embedding vector")
                    return False
                embeddings.append(emb_vec)
        # Sanity check before writing: one embedding per document.
        if len(embeddings) != len(docs):
            print("Chroma init aborted: embedding count mismatch")
            return False
        col.add(ids=ids, documents=docs, metadatas=metas, embeddings=embeddings)
        print(f"Chroma initialized with {len(docs)} entries")
        return True
    except Exception as e:
        print(f"Chroma initialization error: {e}")
        return False
def local_retrieve(question, top_k=3, category_filter=None):
    """Keyword-overlap retrieval over JSON data on disk with optional category filtering."""
    json_path = os.path.join(os.getcwd(), "data.json")

    documents = []
    metadatas = []
    if os.path.exists(json_path):
        try:
            with open(json_path, "r", encoding="utf-8") as fh:
                entries = json.load(fh)
            for entry in entries:
                category = entry.get("category", "General")
                # Apply category filter if provided
                if category_filter and category != category_filter:
                    continue
                qtext = str(entry.get("question", "")).strip()
                atext = str(entry.get("answer", "")).strip()
                if not (qtext or atext):
                    continue
                documents.append(f"Question: {qtext}\nAnswer: {atext}")
                # Store metadata with each item
                metadatas.append({
                    "source": json_path,
                    "qa_id": str(entry.get("id", "")),
                    "category": category,
                    "question": qtext,
                    "answer": atext
                })
        except Exception as e:
            print(f"Error reading data.json: {e}")
            documents = []
            metadatas = []

    if not documents:
        return [], []

    # Normalized keyword-overlap scoring over both question and answer fields.
    query_tokens = set(_normalize_question(question).split())
    if not query_tokens:
        return [], []

    def _overlap_score(meta):
        # Weight question overlap higher than answer overlap.
        question_tokens = set(_normalize_question(str(meta.get("question", ""))).split())
        answer_tokens = set(_normalize_question(str(meta.get("answer", ""))).split())
        return (2.0 * len(query_tokens & question_tokens)
                + 1.0 * len(query_tokens & answer_tokens))

    # Stable sort on score alone preserves the original document order
    # among ties, matching the index-based ranking it replaces.
    ranked = sorted(
        ((doc, meta, _overlap_score(meta)) for doc, meta in zip(documents, metadatas)),
        key=lambda triple: triple[2],
        reverse=True,
    )
    hits = [(doc, meta) for doc, meta, score in ranked[:top_k] if score > 0]
    return [doc for doc, _ in hits], [meta for _, meta in hits]
| def _normalize_question(text: str) -> str: | |
| """Normalize questions for deterministic exact-match lookup.""" | |
| text = (text or "").lower() | |
| # Arabic orthographic normalization to improve matching robustness. | |
| text = re.sub("[إأآا]", "ا", text) | |
| text = re.sub("ى", "ي", text) | |
| text = re.sub("ؤ", "ء", text) | |
| text = re.sub("ئ", "ء", text) | |
| text = re.sub("ة", "ه", text) | |
| # Keep both Latin and Arabic word characters so exact match works bilingually. | |
| return " ".join(re.findall(r"[\w\u0600-\u06FF]+", text, flags=re.UNICODE)) | |
def local_exact_match(question, category_filter=None):
    """Return exact question match from data.json if available.

    Tries a normalized exact match first; when none exists, falls back to
    the closest near-match by token Jaccard similarity (threshold 0.45) to
    tolerate paraphrases/translations.

    Returns:
        (doc, meta) on a hit, or (None, None) when data.json is missing or
        unreadable, the question normalizes to nothing, or no entry is
        close enough.
    """
    json_path = os.path.join(os.getcwd(), "data.json")
    if not os.path.exists(json_path):
        return None, None
    target = _normalize_question(question)
    if not target:
        return None, None
    try:
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        candidates = []
        for entry in data:
            category = entry.get("category", "General")
            if category_filter and category != category_filter:
                continue
            qtext = str(entry.get("question", "")).strip()
            if _normalize_question(qtext) == target:
                atext = str(entry.get("answer", "")).strip()
                doc = f"Question: {qtext}\nAnswer: {atext}"
                meta = {
                    "source": json_path,
                    "qa_id": str(entry.get("id", "")),
                    "category": category,
                    "question": qtext,
                    "answer": atext,
                }
                return doc, meta
            candidates.append(entry)
        # Near-exact fallback for paraphrases/translations.
        target_tokens = set(target.split())
        if not target_tokens:
            return None, None
        best_entry = None
        best_score = 0.0
        for entry in candidates:
            qtext = str(entry.get("question", "")).strip()
            q_norm = _normalize_question(qtext)
            if not q_norm:
                continue
            q_tokens = set(q_norm.split())
            if not q_tokens:
                continue
            # Jaccard similarity over normalized token sets.
            overlap = len(target_tokens & q_tokens)
            union = len(target_tokens | q_tokens)
            score = overlap / max(union, 1)
            if score > best_score:
                best_score = score
                best_entry = entry
        if best_entry is not None and best_score >= 0.45:
            qtext = str(best_entry.get("question", "")).strip()
            atext = str(best_entry.get("answer", "")).strip()
            doc = f"Question: {qtext}\nAnswer: {atext}"
            meta = {
                "source": json_path,
                "qa_id": str(best_entry.get("id", "")),
                "category": best_entry.get("category", "General"),
                "question": qtext,
                "answer": atext,
            }
            return doc, meta
        # BUG FIX: the original fell off the end of the try block here and
        # returned a bare None when no candidate cleared the threshold;
        # callers unpack two values, so that crashed with a TypeError.
        return None, None
    except Exception as e:
        print(f"Error in exact match lookup: {e}")
        return None, None
# =========================
# Context Formatting
# =========================
def format_context(docs, _metas):
    """Format retrieved documents as a bulleted context string.

    Args:
        docs: Retrieved document strings.
        _metas: Parallel metadata list; accepted for interface
            compatibility but currently unused.

    Returns:
        Newline-joined "- <doc>" bullet lines ("" for an empty list).
    """
    # Join directly instead of an indexed loop whose index was never used.
    return "\n".join(f"- {doc}" for doc in docs)
# =========================
# Save Results (Optional)
# =========================
def save_result(query, answer, sources):
    """Save query results to a log file."""
    # Append one JSON object per line (JSONL) under logs/history.jsonl,
    # creating the directory on first use.
    os.makedirs("logs", exist_ok=True)
    entry = {"query": query, "answer": answer, "sources": sources}
    with open("logs/history.jsonl", "a", encoding="utf-8") as log_file:
        log_file.write(json.dumps(entry, ensure_ascii=False) + "\n")