Spaces:
Running
Running
| """ | |
| Court Session Manager. | |
| Single source of truth for everything that happens in a moot court session. | |
| Every agent reads from and writes to the session object. | |
| Sessions persist to HuggingFace Dataset for durability across container restarts. | |
| Session lifecycle: | |
| created → briefing → rounds → cross_examination → closing → completed | |
| WHY store to HF Dataset? | |
| HF Spaces containers are ephemeral. Without durable storage, all session | |
| data is lost on restart. HF Dataset API gives us free durable storage | |
| using the same HF_TOKEN already in the Space secrets. | |
| """ | |
| import os | |
| import json | |
| import uuid | |
| import logging | |
| from datetime import datetime, timezone | |
| from typing import Optional, Dict, List, Any | |
| from dataclasses import dataclass, field, asdict | |
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Hub credentials and the dataset repo that mirrors the sessions.
# HF_TOKEN is None when the environment variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
SESSIONS_REPO = "CaffeinatedCoding/nyayasetu-court-sessions"

# ── In-memory session store ────────────────────────────────────
# Primary store during runtime. HF Dataset is the durable backup.
# Maps session_id -> session state held as a plain dict.
_sessions: Dict[str, Dict] = dict()
| # ── Data structures ──────────────────────────────────────────── | |
@dataclass
class TranscriptEntry:
    """A single entry in the court transcript.

    Declared as a dataclass so it can be built with keyword arguments and
    converted with ``dataclasses.asdict`` before being stored in a session.
    """

    speaker: str       # JUDGE | OPPOSING_COUNSEL | REGISTRAR | PETITIONER | RESPONDENT
    role_label: str    # Display label e.g. "HON'BLE COURT", "RESPONDENT'S COUNSEL"
    content: str       # The actual text
    round_number: int  # Which round this belongs to
    phase: str         # briefing | argument | cross_examination | closing
    timestamp: str     # ISO timestamp
    entry_type: str    # argument | question | observation | objection | ruling | document | trap
    # Extra data e.g. trap_type, precedents_cited. default_factory gives every
    # instance its own dict instead of sharing one mutable default.
    metadata: Dict = field(default_factory=dict)
@dataclass
class Concession:
    """A concession made by the user during the session.

    Declared as a dataclass so it can be built with keyword arguments and
    converted with ``dataclasses.asdict`` before being stored in a session.
    """

    round_number: int
    exact_quote: str         # The exact text where concession was made
    legal_significance: str  # What opposing counsel can do with this
    exploited: bool = False  # Has opposing counsel used this yet
@dataclass
class TrapEvent:
    """A trap set by opposing counsel.

    Declared as a dataclass so it can be built with keyword arguments and
    converted with ``dataclasses.asdict`` before being stored in a session.
    """

    round_number: int
    trap_type: str           # admission_trap | precedent_trap | inconsistency_trap
    trap_text: str           # What opposing counsel said to set the trap
    user_fell_in: bool       # Whether user fell into the trap
    user_response: str = ""  # What user said in response
@dataclass
class CourtSession:
    """Complete court session state.

    Declared as a dataclass so it can be built with keyword arguments and
    converted with ``dataclasses.asdict``. Every field is required; none has
    a default value.
    """

    # Identity
    session_id: str
    created_at: str
    updated_at: str

    # Case
    case_title: str
    user_side: str     # petitioner | respondent
    user_client: str
    opposing_party: str
    legal_issues: List[str]
    brief_facts: str
    jurisdiction: str  # supreme_court | high_court | district_court

    # Setup
    bench_composition: str  # single | division | constitutional
    difficulty: str         # moot | standard | adversarial
    session_length: str     # brief | standard | extended
    show_trap_warnings: bool

    # Derived from research session import
    imported_from_session: Optional[str]  # NyayaSetu research session ID
    case_brief: str                       # Generated case brief text
    retrieved_precedents: List[Dict]      # Precedents from research session

    # Session progress
    phase: str       # briefing | rounds | cross_examination | closing | completed
    current_round: int
    max_rounds: int  # 3 | 5 | 8

    # Transcript
    transcript: List[Dict]  # List of TranscriptEntry as dicts

    # Tracking
    concessions: List[Dict]         # List of Concession as dicts
    trap_events: List[Dict]         # List of TrapEvent as dicts
    cited_precedents: List[str]     # Judgment IDs cited during session
    documents_produced: List[Dict]  # Documents generated during session

    # Arguments tracking for inconsistency detection
    user_arguments: List[Dict]  # [{round, text, key_claims: []}]

    # Analysis (populated at end)
    analysis: Optional[Dict]
    outcome_prediction: Optional[str]
    performance_score: Optional[float]
def create_session(
    case_title: str,
    user_side: str,
    user_client: str,
    opposing_party: str,
    legal_issues: List[str],
    brief_facts: str,
    jurisdiction: str,
    bench_composition: str,
    difficulty: str,
    session_length: str,
    show_trap_warnings: bool,
    imported_from_session: Optional[str] = None,
    case_brief: str = "",
    retrieved_precedents: Optional[List[Dict]] = None,
) -> str:
    """
    Register a brand-new court session in the in-memory store.

    The session starts in the "briefing" phase at round 0 with empty
    transcript/tracking lists and no analysis. ``max_rounds`` is derived
    from ``session_length`` (brief=3, standard=5, extended=8); any other
    value falls back to 5. A missing ``retrieved_precedents`` becomes [].

    Returns the generated session_id (a UUID4 string).
    """
    session_id = str(uuid.uuid4())
    created = datetime.now(timezone.utc).isoformat()
    rounds_by_length = {"brief": 3, "standard": 5, "extended": 8}

    # Values taken straight from the caller.
    caller_fields = dict(
        case_title=case_title,
        user_side=user_side,
        user_client=user_client,
        opposing_party=opposing_party,
        legal_issues=legal_issues,
        brief_facts=brief_facts,
        jurisdiction=jurisdiction,
        bench_composition=bench_composition,
        difficulty=difficulty,
        session_length=session_length,
        show_trap_warnings=show_trap_warnings,
        imported_from_session=imported_from_session,
        case_brief=case_brief,
        retrieved_precedents=retrieved_precedents or [],
    )

    # Starting values for progress, transcript, tracking and analysis.
    initial_state = dict(
        phase="briefing",
        current_round=0,
        max_rounds=rounds_by_length.get(session_length, 5),
        transcript=[],
        concessions=[],
        trap_events=[],
        cited_precedents=[],
        documents_produced=[],
        user_arguments=[],
        analysis=None,
        outcome_prediction=None,
        performance_score=None,
    )

    session = CourtSession(
        session_id=session_id,
        created_at=created,
        updated_at=created,
        **caller_fields,
        **initial_state,
    )

    # Sessions are stored as plain dicts, not dataclass instances.
    _sessions[session_id] = asdict(session)
    logger.info(f"Session created: {session_id} | {case_title}")
    return session_id
def get_session(session_id: str) -> Optional[Dict]:
    """Look up a session in the in-memory store.

    Returns the stored dict itself (not a copy), or None when the id is
    unknown.
    """
    try:
        return _sessions[session_id]
    except KeyError:
        return None
def update_session(session_id: str, updates: Dict) -> bool:
    """Merge ``updates`` into a stored session and trigger persistence.

    Returns False (after logging a warning) when the session id is unknown,
    True otherwise. ``updated_at`` is always overwritten with the current
    UTC time, even if ``updates`` supplies its own value.
    """
    try:
        record = _sessions[session_id]
    except KeyError:
        logger.warning(f"Session not found: {session_id}")
        return False

    record.update(updates)
    record["updated_at"] = datetime.now(timezone.utc).isoformat()

    # Hand the session to the HF Dataset persistence helper.
    _persist_session(session_id)
    return True
def add_transcript_entry(
    session_id: str,
    speaker: str,
    role_label: str,
    content: str,
    entry_type: str = "argument",
    metadata: Optional[Dict] = None,
) -> bool:
    """Append one transcript entry to a session and persist it.

    The entry is stamped with the session's current round and phase and the
    current UTC time. Returns False when the session cannot be found,
    True once the entry has been appended.
    """
    session = get_session(session_id)
    if not session:
        return False

    record = TranscriptEntry(
        speaker=speaker,
        role_label=role_label,
        content=content,
        round_number=session["current_round"],
        phase=session["phase"],
        timestamp=datetime.now(timezone.utc).isoformat(),
        entry_type=entry_type,
        metadata=metadata or {},
    )

    # The transcript holds plain dicts, so convert before storing.
    session["transcript"].append(asdict(record))
    session["updated_at"] = datetime.now(timezone.utc).isoformat()

    _persist_session(session_id)
    return True
def add_concession(
    session_id: str,
    exact_quote: str,
    legal_significance: str,
) -> bool:
    """Store a user concession against the session's current round.

    Returns False when the session cannot be found, True otherwise.
    Only the in-memory session is modified; this function does not itself
    call the HF persistence helper.
    """
    session = get_session(session_id)
    if not session:
        return False

    record = Concession(
        round_number=session["current_round"],
        exact_quote=exact_quote,
        legal_significance=legal_significance,
    )
    session["concessions"].append(asdict(record))
    session["updated_at"] = datetime.now(timezone.utc).isoformat()

    # Only the first 80 characters of the quote go to the log.
    logger.info(f"Concession recorded in session {session_id}: {exact_quote[:80]}")
    return True
def add_trap_event(
    session_id: str,
    trap_type: str,
    trap_text: str,
    user_fell_in: bool = False,
    user_response: str = "",
) -> bool:
    """Store a trap event against the session's current round.

    Returns False when the session cannot be found, True otherwise.
    Only the in-memory session is modified; this function does not itself
    call the HF persistence helper.
    """
    session = get_session(session_id)
    if session:
        event = TrapEvent(
            round_number=session["current_round"],
            trap_type=trap_type,
            trap_text=trap_text,
            user_fell_in=user_fell_in,
            user_response=user_response,
        )
        session["trap_events"].append(asdict(event))
        session["updated_at"] = datetime.now(timezone.utc).isoformat()
        return True
    return False
def add_user_argument(
    session_id: str,
    argument_text: str,
    key_claims: List[str],
) -> bool:
    """Track user's argument for inconsistency detection.

    Appends ``{round, text, key_claims, timestamp}`` to the session's
    ``user_arguments`` list, using the session's current round and the
    current UTC time, and refreshes ``updated_at`` like the other
    session-mutating helpers do.

    Returns False when the session cannot be found, True otherwise.
    Only the in-memory session is modified; this function does not itself
    call the HF persistence helper.
    """
    session = get_session(session_id)
    if not session:
        return False

    now = datetime.now(timezone.utc).isoformat()
    session["user_arguments"].append({
        "round": session["current_round"],
        "text": argument_text,
        "key_claims": key_claims,
        "timestamp": now,
    })
    # Keep updated_at in step with the mutation so that sorting sessions by
    # recency reflects this change too.
    session["updated_at"] = now
    return True
def advance_phase(session_id: str) -> str:
    """
    Step the session forward to the phase that follows its current one.

    Order: briefing → rounds → cross_examination → closing → completed.
    A phase that has no successor in that order (including "completed")
    resolves to "completed".

    Returns the new phase name, or "" when the session cannot be found.
    """
    session = get_session(session_id)
    if not session:
        return ""

    phase_order = ("briefing", "rounds", "cross_examination", "closing", "completed")
    # Pair each phase with the one that comes after it.
    successor_of = dict(zip(phase_order, phase_order[1:]))

    current = session["phase"]
    next_phase = successor_of.get(current, "completed")

    update_session(session_id, {"phase": next_phase})
    logger.info(f"Session {session_id} advanced: {current} → {next_phase}")
    return next_phase
def advance_round(session_id: str) -> int:
    """Bump the session's round counter by one.

    When the new round number exceeds ``max_rounds`` while the session is
    in the "rounds" phase, the phase is advanced first. The incremented
    counter is stored in either case.

    Returns the new round number, or 0 when the session cannot be found.
    """
    session = get_session(session_id)
    if not session:
        return 0

    new_round = session["current_round"] + 1

    # Auto-advance phase when max rounds reached
    if new_round > session["max_rounds"]:
        if session["phase"] == "rounds":
            advance_phase(session_id)

    update_session(session_id, {"current_round": new_round})
    return new_round
def get_all_sessions() -> List[Dict]:
    """List every in-memory session, most recently updated first.

    Sessions without an ``updated_at`` key are ordered as if it were "".
    """
    ordered = list(_sessions.values())
    ordered.sort(key=lambda record: record.get("updated_at", ""), reverse=True)
    return ordered
def get_session_transcript_text(session_id: str) -> str:
    """
    Render the session transcript as plain text for LLM consumption.

    Output is a header (court, case title, petitioner, respondent) followed
    by "PROCEEDINGS:" and then, for each transcript entry, the upper-cased
    role label, the content, and a blank line.

    Returns "" when the session cannot be found.
    """
    session = get_session(session_id)
    if not session:
        return ""

    court_name = session["jurisdiction"].upper().replace("_", " ")
    case_title = session["case_title"]

    # The user's client is the petitioner or the respondent depending on
    # which side the user argues; the opposing party takes the other role.
    if session["user_side"] == "petitioner":
        petitioner = session["user_client"]
        respondent = session["opposing_party"]
    else:
        petitioner = session["opposing_party"]
        respondent = session["user_client"]

    lines = [
        f"IN THE {court_name}",
        f"Case: {case_title}",
        f"Petitioner: {petitioner}",
        f"Respondent: {respondent}",
        "",
        "PROCEEDINGS:",
        "",
    ]

    for entry in session["transcript"]:
        lines += [entry["role_label"].upper(), entry["content"], ""]

    return "\n".join(lines)
def _persist_session(session_id: str):
    """
    Upload one session to the HuggingFace Dataset repo from a daemon thread.

    Does nothing when HF_TOKEN is not set. Every failure (import, repo
    creation, upload) is caught and logged at debug level, so the caller
    is never interrupted and the in-memory session stays usable.
    """
    if not HF_TOKEN:
        return

    try:
        from huggingface_hub import HfApi
        import threading

        def _push_to_hub():
            try:
                hub = HfApi(token=HF_TOKEN)
                # Serialised inside the worker, so the payload reflects the
                # session as it is when the thread runs.
                payload = json.dumps(_sessions[session_id], ensure_ascii=False)

                # Make sure the dataset repo exists; a failure here is only
                # logged and the upload is attempted anyway.
                try:
                    hub.create_repo(
                        repo_id=SESSIONS_REPO,
                        repo_type="dataset",
                        private=True,
                        exist_ok=True,
                    )
                except Exception as repo_err:
                    logger.debug(f"Could not create/access HF repo: {repo_err}")

                hub.upload_file(
                    path_or_fileobj=payload.encode(),
                    path_in_repo=f"sessions/{session_id}.json",
                    repo_id=SESSIONS_REPO,
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
            except Exception as upload_err:
                logger.debug(f"Session upload to HF failed (working offline): {upload_err}")

        # Run in background thread — never blocks the response
        threading.Thread(target=_push_to_hub, daemon=True).start()
    except Exception as e:
        logger.debug(f"Session persist setup failed (non-critical): {e}")
def load_sessions_from_hf():
    """
    Load all sessions from HF Dataset on startup.
    Called once from api/main.py after download_models().

    Every ``sessions/*.json`` file in the dataset repo is downloaded and,
    when it contains a ``session_id``, placed into the in-memory store
    (replacing any entry with the same id). Nothing is raised: without
    HF_TOKEN a warning is logged and the function returns; if listing the
    repo fails the store is left as it was; a file that cannot be
    downloaded or parsed is logged and skipped.
    """
    if not HF_TOKEN:
        logger.warning("No HF_TOKEN — sessions will not persist across restarts")
        return

    try:
        # Imported once here rather than on every loop iteration.
        from huggingface_hub import HfApi, hf_hub_download

        api = HfApi(token=HF_TOKEN)
        try:
            files = list(api.list_repo_files(
                repo_id=SESSIONS_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            ))
        except Exception as list_err:
            logger.info("No existing sessions on HF — starting fresh")
            # The listing can fail for reasons other than a missing repo;
            # keep the underlying cause visible for debugging.
            logger.debug(f"Listing session files failed: {list_err}")
            return

        session_files = [
            f for f in files
            if f.startswith("sessions/") and f.endswith(".json")
        ]

        loaded = 0
        for filepath in session_files:
            try:
                local_path = hf_hub_download(
                    repo_id=SESSIONS_REPO,
                    filename=filepath,
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
                # Sessions are uploaded as UTF-8 JSON with ensure_ascii=False,
                # so read them back as UTF-8 instead of the locale default.
                with open(local_path, encoding="utf-8") as f:
                    session_data = json.load(f)
                session_id = session_data.get("session_id")
                if session_id:
                    _sessions[session_id] = session_data
                    loaded += 1
            except Exception as file_err:
                # One bad file must not abort the whole restore, but it
                # should not disappear without a trace either.
                logger.warning(f"Skipping session file {filepath}: {file_err}")
                continue

        logger.info(f"Loaded {loaded} sessions from HF Dataset")
    except Exception as e:
        logger.warning(f"Session load from HF failed (non-critical): {e}")