import os
import json
import logging
import uuid
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

from dotenv import load_dotenv

# NOTE(review): load_dotenv() ran before the fastmcp import in the original;
# the order is kept in case fastmcp reads environment settings at import time.
load_dotenv()

from fastmcp import FastMCP

logger = logging.getLogger(__name__)


# -------- Config: data root resolution (robust, works locally & on HF) --------
def resolve_data_root() -> Path:
    """Return the directory that holds the training data.

    Resolution order:
      1. The ``DATA_ROOT`` environment variable, if set and the path exists.
      2. The first existing path among ``./TRAINING DATA``, ``./training_data``
         and ``./data`` (relative to the current working directory).
      3. ``./TRAINING DATA``, created on demand so later code never sees a
         missing root.
    """
    # 1) Env var wins if set
    env = os.getenv("DATA_ROOT")
    if env:
        p = Path(env).expanduser().resolve()
        if p.exists():
            return p
        # Do not fail hard, but make the misconfiguration visible.
        logger.warning("DATA_ROOT=%r does not exist; falling back to defaults", env)

    # 2) Try common repo-relative locations (first that exists wins)
    candidates = [
        Path.cwd() / "TRAINING DATA",  # exact folder name, including the space
        Path.cwd() / "training_data",
        Path.cwd() / "data",
    ]
    for c in candidates:
        if c.exists():
            return c.resolve()

    # 3) Last resort: create ./TRAINING DATA to avoid crashes
    fallback = Path.cwd() / "TRAINING DATA"
    fallback.mkdir(parents=True, exist_ok=True)
    return fallback.resolve()


DATA_ROOT: Path = resolve_data_root()
STUDENTS_DIR: Path = DATA_ROOT / "students"


# -------- Small utilities --------
def _receipt() -> Dict[str, Any]:
    """Return a small provenance record attached to every tool response."""
    return {
        "tool_used": True,
        "server_time": datetime.now(timezone.utc).isoformat(),
        "request_id": str(uuid.uuid4()),
        "data_root": str(DATA_ROOT),
    }


def _norm(s: Optional[str]) -> str:
    """Return *s* stripped of surrounding whitespace; ``None``/empty -> ``""``."""
    return (s or "").strip()


def _ci_contains(hay: Optional[str], needle: Optional[str]) -> bool:
    """Case-insensitive substring test on normalized strings.

    An empty *needle* is contained in every *hay* (plain ``in`` semantics).
    """
    return _norm(needle).lower() in _norm(hay).lower()


# -------- Load metadata (from DATA_ROOT/students/*/metadata.json and/or DATA_ROOT/metadata.json) --------
# In-memory index: { student_name_lower: {"name":..., "email":..., ...} }
_METADATA_BY_STUDENT: Dict[str, Dict[str, Any]] = {}


def _register_student(
    index: Dict[str, Dict[str, Any]],
    data: Any,
    source: Path,
    fallback_name: Optional[str] = None,
) -> bool:
    """Add one student record to *index*; return False if it is unusable.

    A record is unusable when it is not a JSON object or when neither its
    ``name`` field nor *fallback_name* yields a non-blank name. The record is
    annotated in place with ``name`` (always a string) and ``__path``.
    """
    if not isinstance(data, dict):
        return False
    raw = data.get("name")
    # Non-string names (e.g. a numeric id) are stringified instead of crashing _norm().
    name = raw if isinstance(raw, str) else (str(raw) if raw else "")
    if not name.strip():
        name = fallback_name or ""
    key = _norm(name).lower()
    if not key:
        return False
    data["name"] = name
    data["__path"] = str(source)
    index[key] = data
    return True


def _top_level_records(blob: Any) -> List[Any]:
    """Extract candidate student records from a parsed top-level metadata.json.

    Accepted shapes:
      - {"students": [{...}, {...}]}
      - [{"name": ..., ...}, ...]
      - {"Student Name": {...}, ...}   (the key becomes the default name)
    Anything else yields an empty list.
    """
    if isinstance(blob, dict) and isinstance(blob.get("students"), list):
        return blob["students"]
    if isinstance(blob, list):
        return blob
    records: List[Any] = []
    if isinstance(blob, dict):
        for k, v in blob.items():
            if isinstance(v, dict):
                v.setdefault("name", k)
                records.append(v)
    return records


def _load_all_metadata() -> None:
    """(Re)build the in-memory student index from disk.

    Sources, in order (a later source overrides an earlier one for the same
    lower-cased name):
      A) ``STUDENTS_DIR/*/metadata.json`` - one file per student folder; the
         folder name is used when the file has no usable ``name``.
      B) ``DATA_ROOT/metadata.json`` - optional aggregate file.

    Unreadable or malformed files and records are logged and skipped. The new
    index is built locally and swapped in at the end so readers never observe
    a half-populated index during a reload.
    """
    global _METADATA_BY_STUDENT
    index: Dict[str, Dict[str, Any]] = {}

    # A) Per-student folders
    if STUDENTS_DIR.is_dir():
        for student_dir in sorted(STUDENTS_DIR.iterdir()):
            if not student_dir.is_dir():
                continue
            meta_file = student_dir / "metadata.json"
            if not meta_file.exists():
                continue
            try:
                data = json.loads(meta_file.read_text(encoding="utf-8"))
            except Exception as exc:  # best-effort loader: never crash on one bad file
                logger.warning("Skipping unreadable metadata file %s: %s", meta_file, exc)
                continue
            if not _register_student(index, data, meta_file, fallback_name=student_dir.name):
                logger.warning("Skipping %s: not a JSON object with a usable name", meta_file)

    # B) Optional top-level metadata.json (may contain a list or a dict of students)
    top_meta = DATA_ROOT / "metadata.json"
    if top_meta.exists():
        try:
            blob = json.loads(top_meta.read_text(encoding="utf-8"))
        except Exception as exc:  # best-effort loader: never crash on one bad file
            logger.warning("Skipping unreadable metadata file %s: %s", top_meta, exc)
            blob = None
        # One malformed entry must not discard the remaining entries.
        for data in _top_level_records(blob):
            _register_student(index, data, top_meta)

    _METADATA_BY_STUDENT = index


# Initial load
_load_all_metadata()


# -------- OpenAI embeddings (for Pinecone RAG) --------
from openai import OpenAI

_openai_client: Optional[OpenAI] = None
_EMBED_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")


def _embed(texts: List[str]) -> List[List[float]]:
    """
    Embed a batch of strings using OpenAI embeddings.
    """
    global _openai_client
    if _openai_client is None:
        # Client is created lazily on first use and then reused.
        _openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    # OpenAI Python SDK v1 returns .data with embeddings in order
    resp = _openai_client.embeddings.create(model=_EMBED_MODEL, input=texts)
    return [d.embedding for d in resp.data]


# -------- Pinecone client --------
from pinecone import Pinecone as _Pinecone

_pine: Optional[_Pinecone] = None

# Upper bound applied to every top_k sent to Pinecone by this server.
_MAX_TOP_K = 50


def _pc() -> _Pinecone:
    """Return the lazily created Pinecone client.

    Raises:
        RuntimeError: if ``PINECONE_API_KEY`` is not set.
    """
    global _pine
    if _pine is None:
        api_key = os.getenv("PINECONE_API_KEY")
        if not api_key:
            raise RuntimeError("PINECONE_API_KEY not set")
        _pine = _Pinecone(api_key=api_key)
    return _pine


def _pine_index():
    """Return a handle to the index named by ``PINECONE_INDEX_NAME``.

    Raises:
        RuntimeError: if the API key or the index name is not configured.
    """
    pc = _pc()
    index_name = os.getenv("PINECONE_INDEX_NAME")
    if not index_name:
        raise RuntimeError("PINECONE_INDEX_NAME not set")
    return pc.Index(index_name)


def _clamp_top_k(value: Any) -> int:
    """Coerce *value* to int and clamp it to the range 1.._MAX_TOP_K."""
    return max(1, min(int(value), _MAX_TOP_K))


# -------- MCP server --------
mcp = FastMCP("ProjectRAGServer")


@mcp.tool
def add(a: int, b: int) -> Dict[str, Any]:
    """Add two numbers."""
    return {"result": int(a) + int(b), "_receipt": _receipt()}


@mcp.tool
def list_students() -> Dict[str, Any]:
    """Return all known student names."""
    names = sorted(v.get("name") for v in _METADATA_BY_STUDENT.values() if v.get("name"))
    return {"students": names, "count": len(names), "_receipt": _receipt()}


@mcp.tool
def get_student_metadata(name: str) -> Dict[str, Any]:
    """Return full metadata for a student by name (case-insensitive)."""
    key = _norm(name).lower()
    data = _METADATA_BY_STUDENT.get(key)
    if not data:
        return {"error": f"Student '{name}' not found.", "_receipt": _receipt()}
    return {"metadata": data, "_receipt": _receipt()}


@mcp.tool
def get_student_email(name: str) -> Dict[str, Any]:
    """Return the email address for a student by name."""
    key = _norm(name).lower()
    data = _METADATA_BY_STUDENT.get(key)
    if not data:
        return {"error": f"Student '{name}' not found.", "_receipt": _receipt()}
    email = data.get("email")
    if not email:
        return {"error": f"No email in metadata for '{data.get('name')}'.", "_receipt": _receipt()}
    return {"name": data.get("name"), "email": email, "_receipt": _receipt()}


@mcp.tool
def search_student_by_field(field: str, value: str) -> Dict[str, Any]:
    """
    Case-insensitive contains() search across any metadata field.
    Example: field='department', value='Computer'
    """
    f = _norm(field)
    val = _norm(value)
    if not f:
        return {"error": "Field must be provided.", "_receipt": _receipt()}

    matches: List[Dict[str, Any]] = []
    for meta in _METADATA_BY_STUDENT.values():
        if f not in meta:
            continue
        v = meta[f]
        # Allow both strings and list-of-strings; other value types never match.
        if isinstance(v, str):
            hit = _ci_contains(v, val)
        elif isinstance(v, list):
            hit = any(_ci_contains(x, val) for x in v if isinstance(x, str))
        else:
            hit = False
        if hit:
            matches.append({"name": meta.get("name"), "match_value": v, "metadata": meta})

    return {"matches": matches, "count": len(matches), "_receipt": _receipt()}


@mcp.tool
def reload_metadata() -> Dict[str, Any]:
    """Reload metadata from disk (useful after updating files)."""
    # perf_counter is monotonic, so the reported duration cannot go negative
    # if the wall clock is adjusted during the reload.
    t0 = time.perf_counter()
    _load_all_metadata()
    dt = round((time.perf_counter() - t0) * 1000.0, 2)
    return {"ok": True, "students": len(_METADATA_BY_STUDENT), "ms": dt, "_receipt": _receipt()}


@mcp.tool
def search_rag(query: str, top_k: int = 3, namespace: Optional[str] = None) -> Dict[str, Any]:
    """
    Semantic search over your Pinecone index using OpenAI embeddings.
    Returns top_k matches with metadata.

    Env required: OPENAI_API_KEY, PINECONE_API_KEY, PINECONE_INDEX_NAME
    Optional: OPENAI_EMBEDDING_MODEL (default: text-embedding-3-small)
    """
    q = _norm(query)
    if not q:
        return {"error": "Query must not be empty.", "_receipt": _receipt()}

    # Each stage reports its own error so the caller can tell which
    # dependency (embedding, client setup, query) failed.
    try:
        vec = _embed([q])[0]
    except Exception as e:
        return {"error": f"Embedding failed: {e}", "_receipt": _receipt()}

    try:
        idx = _pine_index()
    except Exception as e:
        return {"error": f"Pinecone init failed: {e}", "_receipt": _receipt()}

    try:
        resp = idx.query(
            vector=vec,
            top_k=_clamp_top_k(top_k),
            include_metadata=True,
            namespace=namespace or None,
        )
    except Exception as e:
        return {"error": f"Pinecone query failed: {e}", "_receipt": _receipt()}

    results = [
        {
            "id": getattr(m, "id", None),
            "score": getattr(m, "score", None),
            "metadata": getattr(m, "metadata", None),
        }
        for m in getattr(resp, "matches", []) or []
    ]

    return {
        "query": q,
        "model": _EMBED_MODEL,
        "top_k": top_k,
        "results": results,
        "_receipt": _receipt(),
    }


def _student_name_from_match(md: Dict[str, Any]) -> Optional[str]:
    """Resolve a student name from a vector match's metadata, or None.

    Tries the keys ``student``, ``student_name``, ``name`` and ``author``;
    otherwise takes the path segment following ``students`` in
    ``source_path`` / ``path`` / ``file``
    (e.g. ".../students/Ahmed/summary.txt" gives "Ahmed").
    """
    name = md.get("student") or md.get("student_name") or md.get("name") or md.get("author")
    if not name:
        src = md.get("source_path") or md.get("path") or md.get("file") or ""
        # heuristic: find a segment after "students"
        parts = [p for p in str(src).replace("\\", "/").split("/") if p]
        if "students" in parts:
            nxt = parts.index("students") + 1
            if nxt < len(parts):
                name = parts[nxt]
    return str(name) if name else None


def _select_candidates(
    rows: Iterable[Dict[str, Any]],
    n: int,
    ensure_distinct_projects: bool,
) -> List[Dict[str, Any]]:
    """Rank aggregated student rows and return at most ``max(1, n)`` of them.

    Ranking: more topic coverage first, then higher score, then name.
    With *ensure_distinct_projects*, a row whose non-empty project title
    (case-insensitive) was already selected is skipped.
    """
    limit = max(1, int(n))
    ranked = sorted(
        rows,
        key=lambda r: (-r["coverage"], -float(r.get("score") or 0.0), (r.get("student") or "")),
    )
    selected: List[Dict[str, Any]] = []
    used_projects: set = set()
    for row in ranked:
        if ensure_distinct_projects:
            # str() guards against a non-string project_title in the metadata.
            ptitle = str(row.get("project_title") or "").strip().lower()
            if ptitle:
                if ptitle in used_projects:
                    # same project title already present -> skip to keep them distinct
                    continue
                used_projects.add(ptitle)
        selected.append(row)
        if len(selected) >= limit:
            break
    return selected


@mcp.tool
def find_candidates(
    topics: List[str] | str,
    n: int = 2,
    prefer_fields: Optional[List[str]] = None,
    ensure_distinct_projects: bool = True,
    top_k_per_topic: int = 5,
    namespace: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Find N student candidates for the given topic(s).
    1) Try metadata search first across prefer_fields (default sensible set).
    2) If still short, fallback to Pinecone RAG per topic.
    3) Aggregate by student with evidence; enforce distinct projects if requested.

    Args:
      topics: A string or list of topic strings (e.g., "reinforcement learning" or ["RL","traffic"]).
      n: number of distinct students to return.
      prefer_fields: metadata fields to search (case-insensitive contains).
                     Defaults to ["research_interests", "skills", "keywords", "department",
                     "project_title", "summary", "areas", "tags"].
      ensure_distinct_projects: if True, avoid picking >1 student with the same project_title.
      top_k_per_topic: RAG candidates to pull per topic.
      namespace: optional Pinecone namespace.

    Returns:
      {
        "requested_topics": [ ... ],
        "candidates": [
          {
            "student": "Full Name",
            "email": "name@example.com",
            "project_title": "...",
            "coverage": 2,              # how many topics this student matched
            "score": 0.74,              # aggregate/max score from RAG/metadata (metadata hits use 1.0)
            "evidence": [
              { "topic":"...", "source":"metadata|rag", "field":"...", "snippet":"...", "score":1.0 },
              ...
            ],
            "metadata": {...}           # full student metadata
          },
          ...
        ],
        "_receipt": {...}
      }
    """
    # --- normalize topics ---
    raw_topics = [topics] if isinstance(topics, str) else topics
    topic_list = [t for t in (_norm(t) for t in raw_topics) if t]
    if not topic_list:
        return {"error": "No non-empty topics provided.", "_receipt": _receipt()}

    # --- fields to prefer in metadata search ---
    fields = prefer_fields or [
        "research_interests", "skills", "keywords", "department",
        "project_title", "summary", "areas", "tags",
    ]

    # --- accumulator keyed by student key (lowercased name) ---
    by_student: Dict[str, Dict[str, Any]] = {}

    def _add_hit(student_key: str, display_name: str, meta: Dict[str, Any],
                 topic: str, source: str, score: float,
                 field: Optional[str] = None, snippet: Optional[str] = None) -> None:
        """Record one piece of evidence and refresh the student's aggregates."""
        row = by_student.setdefault(student_key, {
            "student": display_name,
            "email": meta.get("email"),
            "project_title": meta.get("project_title") or meta.get("title") or meta.get("project"),
            "metadata": meta,
            "coverage": 0,
            "score": 0.0,
            "evidence": [],
        })
        # add evidence
        row["evidence"].append({
            "topic": topic, "source": source, "field": field,
            "snippet": snippet, "score": score,
        })
        # recompute coverage & score
        row["coverage"] = len({ev["topic"] for ev in row["evidence"]})
        # use max score as a simple aggregate; switch to average if preferred
        row["score"] = max(ev.get("score") or 0.0 for ev in row["evidence"])

    # --- 1) METADATA PASS ---
    # At most one hit per (topic, student): the first matching field wins.
    for topic in topic_list:
        for skey, meta in _METADATA_BY_STUDENT.items():
            disp = meta.get("name") or skey
            for f in fields:
                v = meta.get(f)
                if isinstance(v, str) and _ci_contains(v, topic):
                    _add_hit(skey, disp, meta, topic, source="metadata",
                             score=1.0, field=f, snippet=v[:240])
                    break
                if isinstance(v, list) and any(
                    isinstance(x, str) and _ci_contains(x, topic) for x in v
                ):
                    # join just for snippet preview
                    joined = ", ".join(x for x in v if isinstance(x, str))[:240]
                    _add_hit(skey, disp, meta, topic, source="metadata",
                             score=1.0, field=f, snippet=joined)
                    break

    # --- Which topics are still weak in coverage from metadata alone? ---
    # Only metadata evidence exists at this point, so collecting topics suffices.
    covered_topics = {
        ev["topic"]
        for row in by_student.values()
        for ev in row["evidence"]
        if ev["source"] == "metadata"
    }
    topics_needing_rag = [t for t in topic_list if t not in covered_topics]

    # --- 2) RAG PASS (only for topics that missed in metadata) ---
    if topics_needing_rag:
        try:
            idx = _pine_index()
        except Exception as e:
            # If Pinecone is not configured, return what we have - still ranked
            # and limited by n / ensure_distinct_projects like the normal path.
            return {
                "requested_topics": topic_list,
                "candidates": _select_candidates(
                    by_student.values(), n, ensure_distinct_projects
                ),
                "_receipt": _receipt() | {"warning": f"Skipping RAG: {e}"},
            }

        for topic in topics_needing_rag:
            try:
                vec = _embed([topic])[0]
                resp = idx.query(
                    vector=vec,
                    top_k=_clamp_top_k(top_k_per_topic),
                    include_metadata=True,
                    namespace=namespace or None,
                )
            except Exception as exc:
                # continue with the other topics, but leave a trace of the failure
                logger.warning("RAG lookup failed for topic %r: %s", topic, exc)
                continue

            for m in getattr(resp, "matches", []) or []:
                md = getattr(m, "metadata", None) or {}
                score = float(getattr(m, "score", 0.0) or 0.0)

                cand_name = _student_name_from_match(md)
                skey = _norm(cand_name).lower()
                if not skey:
                    # cannot attribute to a student; skip
                    continue

                # prefer known metadata if we already have it
                meta = _METADATA_BY_STUDENT.get(skey, {"name": cand_name})
                snippet = md.get("text") or md.get("chunk") or md.get("content") or ""
                if isinstance(snippet, str):
                    snippet = snippet.strip()[:240]
                else:
                    snippet = str(snippet)[:240]

                _add_hit(skey, meta.get("name") or cand_name, meta, topic,
                         source="rag", score=score, field=None, snippet=snippet)

    # --- RANK & SELECT ---
    return {
        "requested_topics": topic_list,
        "candidates": _select_candidates(by_student.values(), n, ensure_distinct_projects),
        "_receipt": _receipt(),
    }


# ---- HTTP runner for HF Space / local run ----
if __name__ == "__main__":
    # Not referenced below. NOTE(review): presumably imported so a missing HTTP
    # server dependency fails fast at startup - confirm before removing.
    import uvicorn  # noqa: F401

    port = int(os.environ.get("PORT", 7860))
    # HTTP, path /mcp
    mcp.run(transport="http", host="0.0.0.0", port=port, path="/mcp")