Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import uuid | |
| import time | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| from fastmcp import FastMCP | |
| # -------- Config: data root resolution (robust, works locally & on HF) -------- | |
def resolve_data_root() -> Path:
    """Locate the directory that holds the training data.

    Resolution order:
      1. The ``DATA_ROOT`` environment variable, if it names an existing
         directory.
      2. The first existing directory among ``./TRAINING DATA``,
         ``./training_data`` and ``./data`` (relative to the current
         working directory).
      3. ``./TRAINING DATA``, created on demand.

    Returns:
        The resolved (absolute) path of the chosen data root. In the
        last-resort case the directory may not exist if it could not be
        created; callers in this file check for existence before reading.
    """
    # 1) Env var wins if set -- but only when it points at a directory.
    #    A regular file would make every DATA_ROOT / "..." lookup meaningless.
    env = os.getenv("DATA_ROOT")
    if env:
        p = Path(env).expanduser().resolve()
        if p.is_dir():
            return p

    # 2) Try common repo-relative locations (first existing directory wins)
    cwd = Path.cwd()
    candidates = (
        cwd / "TRAINING DATA",  # exact folder name, including the space
        cwd / "training_data",
        cwd / "data",
    )
    for c in candidates:
        if c.is_dir():
            return c.resolve()

    # 3) Last resort: create ./TRAINING DATA to avoid crashes
    fallback = cwd / "TRAINING DATA"
    try:
        fallback.mkdir(parents=True, exist_ok=True)
    except OSError:
        # Read-only filesystem, or a non-directory already uses this name.
        # Failing here would abort the import of the whole module, so return
        # the path anyway and let the loaders find nothing under it.
        pass
    return fallback.resolve()
# Resolved once, at import time; the paths below are derived from it.
DATA_ROOT: Path = resolve_data_root()
# Per-student folders are looked up under DATA_ROOT/students by _load_all_metadata().
STUDENTS_DIR: Path = DATA_ROOT / "students"
| # -------- Small utilities -------- | |
def _receipt() -> Dict[str, Any]:
    """Build the bookkeeping record attached to every tool response.

    Contains a fixed ``tool_used`` flag, the current UTC time in ISO-8601
    form, a freshly generated request id and the active data root.
    """
    stamp = datetime.now(timezone.utc)
    receipt: Dict[str, Any] = {"tool_used": True}
    receipt["server_time"] = stamp.isoformat()
    receipt["request_id"] = str(uuid.uuid4())
    receipt["data_root"] = str(DATA_ROOT)
    return receipt
def _norm(s: Optional[str]) -> str:
    """Return *s* without surrounding whitespace; ``None``/empty becomes ``""``."""
    if not s:
        return ""
    return s.strip()
def _ci_contains(hay: Optional[str], needle: Optional[str]) -> bool:
    """Case-insensitive substring test on whitespace-trimmed inputs.

    ``None`` is treated as the empty string on either side, so an empty or
    missing *needle* is contained in anything.
    """
    haystack = (hay or "").strip().lower()
    fragment = (needle or "").strip().lower()
    return fragment in haystack
| # -------- Load metadata (from DATA_ROOT/students/*/metadata.json and/or DATA_ROOT/metadata.json) -------- | |
| # In-memory index: { student_name_lower: {"name":..., "email":..., ...} } | |
# In-memory index: { student_name_lower: {"name":..., "email":..., ...} }
_METADATA_BY_STUDENT: Dict[str, Dict[str, Any]] = {}


def _load_all_metadata() -> None:
    """Rebuild the in-memory student index from disk.

    Sources, in order (a later source overrides an earlier one for the same
    lower-cased name):
      A) ``STUDENTS_DIR/<folder>/metadata.json`` -- one JSON object per
         student; the folder name is used when the object has no name.
      B) ``DATA_ROOT/metadata.json`` -- either ``{"students": [...]}``,
         a list of objects, or ``{"<name>": {...}, ...}``.

    Loading is best-effort: unreadable files, invalid JSON and entries that
    are not JSON objects are skipped individually instead of aborting the
    whole load. Each stored record gets ``name`` and ``__path`` (the file it
    came from) filled in. The new index replaces the old one in a single
    assignment, so readers never observe a half-built index during a reload.
    """
    global _METADATA_BY_STUDENT
    index: Dict[str, Dict[str, Any]] = {}

    def _read_json(path: Path) -> Any:
        # OSError: unreadable file; ValueError covers JSONDecodeError and
        # UnicodeDecodeError. Anything else is a real bug and should surface.
        try:
            return json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            return None

    def _register(record: Any, fallback_name: Optional[str], origin: Path) -> None:
        # Only JSON objects can be student records.
        if not isinstance(record, dict):
            return
        raw = record.get("name")
        # Coerce to str so a numeric name cannot break _norm() or sorting later.
        name = str(raw) if raw else ""
        if not name.strip():
            name = fallback_name or ""
        if not name.strip():
            return
        record["name"] = name
        record["__path"] = str(origin)
        index[_norm(name).lower()] = record

    # A) Per-student folders
    if STUDENTS_DIR.exists():
        for student_dir in sorted(STUDENTS_DIR.iterdir()):
            if not student_dir.is_dir():
                continue
            meta_file = student_dir / "metadata.json"
            if meta_file.exists():
                _register(_read_json(meta_file), student_dir.name, meta_file)

    # B) Optional top-level metadata.json (may contain a list or a dict of students)
    top_meta = DATA_ROOT / "metadata.json"
    if top_meta.exists():
        blob = _read_json(top_meta)
        if isinstance(blob, dict) and isinstance(blob.get("students"), list):
            for entry in blob["students"]:
                _register(entry, None, top_meta)
        elif isinstance(blob, list):
            for entry in blob:
                _register(entry, None, top_meta)
        elif isinstance(blob, dict):
            # {"<name>": {...}} form: the key is the name when the record has none.
            for key, entry in blob.items():
                _register(entry, str(key), top_meta)

    _METADATA_BY_STUDENT = index
# Initial load: populate the in-memory index once, at import time.
_load_all_metadata()
| # -------- OpenAI embeddings (for Pinecone RAG) -------- | |
| from openai import OpenAI | |
# Lazily created OpenAI client; stays None until _embed() needs it.
_openai_client: Optional[OpenAI] = None
# Embedding model name; the OPENAI_EMBEDDING_MODEL env var overrides the default.
_EMBED_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")
def _embed(texts: List[str]) -> List[List[float]]:
    """Embed a batch of strings using OpenAI embeddings.

    Args:
        texts: The strings to embed.

    Returns:
        One embedding vector per input string, in the order the API returned
        them. An empty batch yields an empty list without contacting the API.

    Raises:
        Whatever the OpenAI client raises (missing/invalid key, network or
        API errors); callers in this file catch and report these.
    """
    global _openai_client
    # An empty input list is rejected by the embeddings endpoint; there is
    # nothing to embed, so skip the request (and client creation) entirely.
    if not texts:
        return []
    if _openai_client is None:
        _openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    # OpenAI Python SDK v1 returns .data with embeddings in order
    resp = _openai_client.embeddings.create(model=_EMBED_MODEL, input=texts)
    return [d.embedding for d in resp.data]
| # -------- Pinecone client -------- | |
| from pinecone import Pinecone as _Pinecone | |
# Process-wide Pinecone client, created on first use by _pc().
_pine: Optional[_Pinecone] = None


def _pc() -> _Pinecone:
    """Return the shared Pinecone client, creating it on first call.

    Raises:
        RuntimeError: if ``PINECONE_API_KEY`` is unset or empty.
    """
    global _pine
    if _pine is not None:
        return _pine
    key = os.getenv("PINECONE_API_KEY")
    if not key:
        raise RuntimeError("PINECONE_API_KEY not set")
    _pine = _Pinecone(api_key=key)
    return _pine
def _pine_index():
    """Return a handle to the Pinecone index named by ``PINECONE_INDEX_NAME``.

    The client is obtained first, so a missing API key is reported before a
    missing index name.

    Raises:
        RuntimeError: if the API key or the index name is not configured.
    """
    client = _pc()
    configured_name = os.getenv("PINECONE_INDEX_NAME")
    if configured_name:
        return client.Index(configured_name)
    raise RuntimeError("PINECONE_INDEX_NAME not set")
| # -------- MCP server -------- | |
# FastMCP server instance; started by the __main__ runner at the end of the file.
# NOTE(review): no function in this view is registered on `mcp` (e.g. through a
# tool decorator) -- confirm the tools below are actually exposed to clients.
mcp = FastMCP("ProjectRAGServer")
def add(a: int, b: int) -> Dict[str, Any]:
    """Add two numbers (both are coerced with ``int`` first) and return the sum with a receipt."""
    total = int(a) + int(b)
    return {"result": total, "_receipt": _receipt()}
def list_students() -> Dict[str, Any]:
    """Return the sorted display names of all indexed students, plus their count."""
    names = [meta["name"] for meta in _METADATA_BY_STUDENT.values() if meta.get("name")]
    names.sort()
    return {"students": names, "count": len(names), "_receipt": _receipt()}
def get_student_metadata(name: str) -> Dict[str, Any]:
    """Look up a student's full metadata record by name.

    The lookup ignores case and surrounding whitespace. Returns
    ``{"metadata": ...}`` on success or ``{"error": ...}`` when the name is
    unknown; both carry a ``_receipt``.
    """
    record = _METADATA_BY_STUDENT.get(_norm(name).lower())
    if record:
        return {"metadata": record, "_receipt": _receipt()}
    return {"error": f"Student '{name}' not found.", "_receipt": _receipt()}
def get_student_email(name: str) -> Dict[str, Any]:
    """Look up a student's email address by name (case-insensitive).

    Returns ``{"name", "email"}`` on success, or ``{"error": ...}`` when the
    student is unknown or the record has no email; all carry a ``_receipt``.
    """
    record = _METADATA_BY_STUDENT.get(_norm(name).lower())
    if not record:
        return {"error": f"Student '{name}' not found.", "_receipt": _receipt()}

    address = record.get("email")
    if address:
        return {"name": record.get("name"), "email": address, "_receipt": _receipt()}
    return {"error": f"No email in metadata for '{record.get('name')}'.", "_receipt": _receipt()}
def search_student_by_field(field: str, value: str) -> Dict[str, Any]:
    """Find students whose metadata *field* contains *value*.

    The field name is matched exactly (after trimming); the value comparison
    is a case-insensitive substring test. A field may hold a string or a
    list; only the string items of a list are examined. Other value types
    never match.
    Example: field='department', value='Computer'
    """
    wanted_field = _norm(field)
    needle = _norm(value)
    if not wanted_field:
        return {"error": "Field must be provided.", "_receipt": _receipt()}

    def _hits(candidate: Any) -> bool:
        # Strings are tested directly; lists match if any string item does.
        if isinstance(candidate, str):
            return _ci_contains(candidate, needle)
        if isinstance(candidate, list):
            return any(_ci_contains(item, needle) for item in candidate if isinstance(item, str))
        return False

    matches: List[Dict[str, Any]] = [
        {"name": meta.get("name"), "match_value": meta.get(wanted_field), "metadata": meta}
        for meta in _METADATA_BY_STUDENT.values()
        if wanted_field in meta and _hits(meta.get(wanted_field))
    ]
    return {"matches": matches, "count": len(matches), "_receipt": _receipt()}
def reload_metadata() -> Dict[str, Any]:
    """Re-read all metadata from disk and report how long it took.

    Returns the number of indexed students and the elapsed wall-clock time
    in milliseconds (rounded to two decimals).
    """
    started = time.time()
    _load_all_metadata()
    elapsed_ms = round((time.time() - started) * 1000.0, 2)
    return {
        "ok": True,
        "students": len(_METADATA_BY_STUDENT),
        "ms": elapsed_ms,
        "_receipt": _receipt(),
    }
def search_rag(query: str, top_k: int = 3, namespace: Optional[str] = None) -> Dict[str, Any]:
    """Semantic search over the Pinecone index using an OpenAI embedding of *query*.

    The query is embedded first, then the index handle is obtained, then the
    index is queried; a failure at any stage is returned as ``{"error": ...}``
    rather than raised. The number of matches requested from Pinecone is
    clamped to 1..50, while the ``top_k`` echoed in the response is the value
    the caller passed.

    Env required: OPENAI_API_KEY, PINECONE_API_KEY, PINECONE_INDEX_NAME
    Optional: OPENAI_EMBEDDING_MODEL (default: text-embedding-3-small)
    """
    q = _norm(query)
    if not q:
        return {"error": "Query must not be empty.", "_receipt": _receipt()}

    # Stage 1: embed the query text.
    try:
        embedded = _embed([q])
        vec = embedded[0]
    except Exception as e:
        return {"error": f"Embedding failed: {e}", "_receipt": _receipt()}

    # Stage 2: resolve the index handle (fails when env vars are missing).
    try:
        index = _pine_index()
    except Exception as e:
        return {"error": f"Pinecone init failed: {e}", "_receipt": _receipt()}

    # Stage 3: run the query; a non-numeric top_k is also reported here.
    try:
        resp = index.query(
            vector=vec,
            top_k=max(1, min(int(top_k), 50)),
            include_metadata=True,
            namespace=namespace or None,
        )
    except Exception as e:
        return {"error": f"Pinecone query failed: {e}", "_receipt": _receipt()}

    found = getattr(resp, "matches", []) or []
    results = [
        {
            "id": getattr(match, "id", None),
            "score": getattr(match, "score", None),
            "metadata": getattr(match, "metadata", None),
        }
        for match in found
    ]
    return {
        "query": q,
        "model": _EMBED_MODEL,
        "top_k": top_k,
        "results": results,
        "_receipt": _receipt(),
    }
| def find_candidates( | |
| topics: List[str] | str, | |
| n: int = 2, | |
| prefer_fields: Optional[List[str]] = None, | |
| ensure_distinct_projects: bool = True, | |
| top_k_per_topic: int = 5, | |
| namespace: Optional[str] = None, | |
| ) -> Dict[str, Any]: | |
| """ | |
| Find N student candidates for the given topic(s). | |
| 1) Try metadata search first across prefer_fields (default sensible set). | |
| 2) If still short, fallback to Pinecone RAG per topic. | |
| 3) Aggregate by student with evidence; enforce distinct projects if requested. | |
| Args: | |
| topics: A string or list of topic strings (e.g., "reinforcement learning" or ["RL","traffic"]). | |
| n: number of distinct students to return. | |
| prefer_fields: metadata fields to search (case-insensitive contains). Defaults to | |
| ["research_interests", "skills", "keywords", "department", "project_title", "summary"]. | |
| ensure_distinct_projects: if True, avoid picking >1 student with the same project_title. | |
| top_k_per_topic: RAG candidates to pull per topic. | |
| namespace: optional Pinecone namespace. | |
| Returns: | |
| { | |
| "requested_topics": [ ... ], | |
| "candidates": [ | |
| { | |
| "student": "Full Name", | |
| "email": "name@example.com", | |
| "project_title": "...", | |
| "coverage": 2, # how many topics this student matched | |
| "score": 0.74, # aggregate/max score from RAG/metadata (metadata hits use 1.0) | |
| "evidence": [ | |
| { "topic":"...", "source":"metadata|rag", "field":"...", "snippet":"...", "score":1.0 }, | |
| ... | |
| ], | |
| "metadata": {...} # full student metadata | |
| }, | |
| ... | |
| ], | |
| "_receipt": {...} | |
| } | |
| """ | |
| # --- normalize topics --- | |
| if isinstance(topics, str): | |
| topic_list = [_norm(topics)] | |
| else: | |
| topic_list = [_norm(t) for t in topics if _norm(t)] | |
| topic_list = [t for t in topic_list if t] | |
| if not topic_list: | |
| return {"error": "No non-empty topics provided.", "_receipt": _receipt()} | |
| # --- fields to prefer in metadata search --- | |
| fields = prefer_fields or [ | |
| "research_interests", "skills", "keywords", "department", | |
| "project_title", "summary", "areas", "tags" | |
| ] | |
| # --- accumulator keyed by student key (lowercased name) --- | |
| by_student: Dict[str, Dict[str, Any]] = {} | |
| def _add_hit(student_key: str, display_name: str, meta: Dict[str, Any], topic: str, | |
| source: str, score: float, field: Optional[str] = None, snippet: Optional[str] = None): | |
| row = by_student.setdefault(student_key, { | |
| "student": display_name, | |
| "email": meta.get("email"), | |
| "project_title": meta.get("project_title") or meta.get("title") or meta.get("project"), | |
| "metadata": meta, | |
| "coverage": 0, | |
| "score": 0.0, | |
| "evidence": [], | |
| }) | |
| # add evidence | |
| row["evidence"].append({ | |
| "topic": topic, | |
| "source": source, | |
| "field": field, | |
| "snippet": snippet, | |
| "score": score | |
| }) | |
| # recompute coverage & score | |
| covered = {ev["topic"] for ev in row["evidence"]} | |
| row["coverage"] = len(covered) | |
| # use max score as a simple aggregate; you can switch to average if you prefer | |
| row["score"] = max([ev.get("score") or 0.0 for ev in row["evidence"]]) | |
| # --- 1) METADATA PASS --- | |
| for topic in topic_list: | |
| for skey, meta in _METADATA_BY_STUDENT.items(): | |
| disp = meta.get("name") or skey | |
| matched = False | |
| for f in fields: | |
| v = meta.get(f) | |
| if isinstance(v, str) and _ci_contains(v, topic): | |
| _add_hit(skey, disp, meta, topic, source="metadata", score=1.0, field=f, snippet=v[:240]) | |
| matched = True | |
| break | |
| elif isinstance(v, list) and any(isinstance(x, str) and _ci_contains(x, topic) for x in v): | |
| # join just for snippet preview | |
| joined = ", ".join([x for x in v if isinstance(x, str)])[:240] | |
| _add_hit(skey, disp, meta, topic, source="metadata", score=1.0, field=f, snippet=joined) | |
| matched = True | |
| break | |
| # if matched, we already recorded; continue to next student | |
| # If we already have enough diverse candidates, great—but still do RAG to improve coverage if some topics lack hits | |
| have_students = set(by_student.keys()) | |
| # --- Which topics are still weak in coverage from metadata alone? --- | |
| def _topic_covered(t: str) -> bool: | |
| for row in by_student.values(): | |
| if any(ev["topic"] == t for ev in row["evidence"] if ev["source"] == "metadata"): | |
| return True | |
| return False | |
| topics_needing_rag = [t for t in topic_list if not _topic_covered(t)] | |
| # --- 2) RAG PASS (only for topics that missed in metadata) --- | |
| if topics_needing_rag: | |
| try: | |
| idx = _pine_index() | |
| except Exception as e: | |
| # If Pinecone not configured, return what we have | |
| return { | |
| "requested_topics": topic_list, | |
| "candidates": sorted( | |
| list(by_student.values()), | |
| key=lambda r: (-r["coverage"], -r["score"], (r["student"] or ""))), | |
| "_receipt": _receipt() | {"warning": f"Skipping RAG: {e}"} | |
| } | |
| for topic in topics_needing_rag: | |
| try: | |
| vec = _embed([topic])[0] | |
| resp = idx.query( | |
| vector=vec, | |
| top_k=max(1, min(int(top_k_per_topic), 50)), | |
| include_metadata=True, | |
| namespace=namespace or None, | |
| ) | |
| except Exception as e: | |
| # continue with others | |
| continue | |
| for m in getattr(resp, "matches", []) or []: | |
| md = getattr(m, "metadata", None) or {} | |
| score = float(getattr(m, "score", 0.0) or 0.0) | |
| # Try to resolve student name from metadata keys commonly used | |
| cand_name = md.get("student") or md.get("student_name") or md.get("name") or md.get("author") | |
| # Else attempt to infer student from a source path like ".../students/Ahmed/summary.txt" | |
| if not cand_name: | |
| src = md.get("source_path") or md.get("path") or md.get("file") or "" | |
| # heuristic: find a segment after "students" | |
| parts = [p for p in str(src).replace("\\", "/").split("/") if p] | |
| if "students" in parts: | |
| idx_students = parts.index("students") | |
| if idx_students + 1 < len(parts): | |
| cand_name = parts[idx_students + 1] | |
| if not cand_name: | |
| # cannot attribute to a student; skip | |
| continue | |
| skey = _norm(cand_name).lower() | |
| # prefer known metadata if we already have it | |
| meta = _METADATA_BY_STUDENT.get(skey, {"name": cand_name}) | |
| snippet = md.get("text") or md.get("chunk") or md.get("content") or "" | |
| if isinstance(snippet, str): | |
| snippet = snippet.strip()[:240] | |
| else: | |
| snippet = str(snippet)[:240] | |
| _add_hit(skey, meta.get("name") or cand_name, meta, topic, source="rag", score=score, field=None, snippet=snippet) | |
| # --- RANK & SELECT --- | |
| pool = list(by_student.values()) | |
| # primary: more topic coverage; secondary: higher score; tertiary: name | |
| pool.sort(key=lambda r: (-r["coverage"], -float(r.get("score") or 0.0), (r.get("student") or ""))) | |
| selected: List[Dict[str, Any]] = [] | |
| used_projects: set[str] = set() | |
| for row in pool: | |
| if ensure_distinct_projects: | |
| ptitle = (row.get("project_title") or "").strip().lower() | |
| if ptitle and ptitle in used_projects: | |
| # same project title already present → skip to keep them distinct | |
| continue | |
| if ptitle: | |
| used_projects.add(ptitle) | |
| selected.append(row) | |
| if len(selected) >= max(1, int(n)): | |
| break | |
| return { | |
| "requested_topics": topic_list, | |
| "candidates": selected, | |
| "_receipt": _receipt() | |
| } | |
| # ---- HTTP runner for HF Space / local run ---- | |
if __name__ == "__main__":
    # NOTE(review): uvicorn is imported but not referenced below -- presumably
    # kept so a missing dependency fails fast at startup; confirm.
    import uvicorn
    # The PORT env var overrides the default of 7860.
    port = int(os.environ.get("PORT", 7860))
    # HTTP transport on all interfaces, served under the /mcp path
    mcp.run(transport="http", host="0.0.0.0", port=port, path="/mcp")