CaffeinatedCoding commited on
Commit
8efa523
·
verified ·
1 Parent(s): e3240a1

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. api/main.py +36 -5
  2. src/citation_graph.py +193 -0
api/main.py CHANGED
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
21
 
22
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
 
 
24
  def download_models():
25
-
26
  hf_token = os.getenv("HF_TOKEN")
27
  if not hf_token:
28
  logger.warning("HF_TOKEN not set — skipping model download.")
@@ -30,35 +30,61 @@ def download_models():
30
  try:
31
  from huggingface_hub import snapshot_download, hf_hub_download
32
  repo_id = "CaffeinatedCoding/nyayasetu-models"
 
33
  if not os.path.exists("models/ner_model"):
34
  logger.info("Downloading NER model...")
35
- snapshot_download(repo_id=repo_id, repo_type="model", allow_patterns="ner_model/*", local_dir="models", token=hf_token)
 
 
 
36
  logger.info("NER model downloaded")
37
  else:
38
  logger.info("NER model already exists")
 
39
  if not os.path.exists("models/faiss_index/index.faiss"):
40
  logger.info("Downloading FAISS index...")
41
  os.makedirs("models/faiss_index", exist_ok=True)
42
- hf_hub_download(repo_id=repo_id, filename="faiss_index/index.faiss", repo_type="model", local_dir="models", token=hf_token)
43
- hf_hub_download(repo_id=repo_id, filename="faiss_index/chunk_metadata.jsonl", repo_type="model", local_dir="models", token=hf_token)
 
 
44
  logger.info("FAISS index downloaded")
45
  else:
46
  logger.info("FAISS index already exists")
 
47
  if not os.path.exists("data/parent_judgments.jsonl"):
48
  logger.info("Downloading parent judgments...")
49
  os.makedirs("data", exist_ok=True)
50
- hf_hub_download(repo_id=repo_id, filename="parent_judgments.jsonl", repo_type="model", local_dir="data", token=hf_token)
 
51
  logger.info("Parent judgments downloaded")
52
  else:
53
  logger.info("Parent judgments already exist")
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  except Exception as e:
55
  logger.error(f"Model download failed: {e}")
56
 
 
57
  download_models()
58
 
59
  from src.ner import load_ner_model
60
  load_ner_model()
61
 
 
 
 
62
  AGENT_VERSION = os.getenv("AGENT_VERSION", "v2")
63
 
64
  if AGENT_VERSION == "v2":
@@ -77,10 +103,12 @@ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], all
77
  if os.path.exists("frontend"):
78
  app.mount("/static", StaticFiles(directory="frontend"), name="static")
79
 
 
80
  class QueryRequest(BaseModel):
81
  query: str
82
  session_id: Optional[str] = None
83
 
 
84
  class QueryResponse(BaseModel):
85
  query: str
86
  answer: str
@@ -92,16 +120,19 @@ class QueryResponse(BaseModel):
92
  truncated: bool
93
  latency_ms: float
94
 
 
95
  @app.get("/")
96
  def serve_frontend():
97
  if os.path.exists("frontend/index.html"):
98
  return FileResponse("frontend/index.html")
99
  return {"name": "NyayaSetu", "version": "2.0.0", "agent": AGENT_VERSION}
100
 
 
101
  @app.get("/health")
102
  def health():
103
  return {"status": "ok", "service": "NyayaSetu", "version": "2.0.0", "agent": AGENT_VERSION}
104
 
 
105
  @app.post("/query", response_model=QueryResponse)
106
  def query(request: QueryRequest):
107
  if not request.query.strip():
 
21
 
22
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
 
24
+
25
  def download_models():
 
26
  hf_token = os.getenv("HF_TOKEN")
27
  if not hf_token:
28
  logger.warning("HF_TOKEN not set — skipping model download.")
 
30
  try:
31
  from huggingface_hub import snapshot_download, hf_hub_download
32
  repo_id = "CaffeinatedCoding/nyayasetu-models"
33
+
34
  if not os.path.exists("models/ner_model"):
35
  logger.info("Downloading NER model...")
36
+ snapshot_download(
37
+ repo_id=repo_id, repo_type="model",
38
+ allow_patterns="ner_model/*", local_dir="models", token=hf_token
39
+ )
40
  logger.info("NER model downloaded")
41
  else:
42
  logger.info("NER model already exists")
43
+
44
  if not os.path.exists("models/faiss_index/index.faiss"):
45
  logger.info("Downloading FAISS index...")
46
  os.makedirs("models/faiss_index", exist_ok=True)
47
+ hf_hub_download(repo_id=repo_id, filename="faiss_index/index.faiss",
48
+ repo_type="model", local_dir="models", token=hf_token)
49
+ hf_hub_download(repo_id=repo_id, filename="faiss_index/chunk_metadata.jsonl",
50
+ repo_type="model", local_dir="models", token=hf_token)
51
  logger.info("FAISS index downloaded")
52
  else:
53
  logger.info("FAISS index already exists")
54
+
55
  if not os.path.exists("data/parent_judgments.jsonl"):
56
  logger.info("Downloading parent judgments...")
57
  os.makedirs("data", exist_ok=True)
58
+ hf_hub_download(repo_id=repo_id, filename="parent_judgments.jsonl",
59
+ repo_type="model", local_dir="data", token=hf_token)
60
  logger.info("Parent judgments downloaded")
61
  else:
62
  logger.info("Parent judgments already exist")
63
+
64
+ # Download citation graph artifacts — only if Kaggle run has completed
65
+ os.makedirs("data", exist_ok=True)
66
+ for fname in ["citation_graph.json", "reverse_citation_graph.json", "title_to_id.json"]:
67
+ if not os.path.exists(f"data/{fname}"):
68
+ logger.info(f"Downloading {fname}...")
69
+ try:
70
+ hf_hub_download(repo_id=repo_id, filename=fname,
71
+ repo_type="model", local_dir="data", token=hf_token)
72
+ logger.info(f"{fname} downloaded")
73
+ except Exception as fe:
74
+ logger.warning(f"{fname} not on Hub yet — skipping: {fe}")
75
+
76
  except Exception as e:
77
  logger.error(f"Model download failed: {e}")
78
 
79
+
80
# Fetch model artifacts from the Hub before the modules below try to load them.
download_models()

# Eagerly load the NER model at import time so the first request does not pay
# the warm-up cost.
from src.ner import load_ner_model
load_ner_model()

# Citation graph artifacts are optional: load_citation_graph() logs and
# degrades gracefully when the files are not on disk yet.
from src.citation_graph import load_citation_graph
load_citation_graph()

# Selects which agent implementation is wired up below (defaults to v2).
AGENT_VERSION = os.getenv("AGENT_VERSION", "v2")
89
 
90
  if AGENT_VERSION == "v2":
 
103
  if os.path.exists("frontend"):
104
  app.mount("/static", StaticFiles(directory="frontend"), name="static")
105
 
106
+
107
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    # The user's question; the /query handler rejects blank strings.
    query: str
    # NOTE(review): presumably links multi-turn requests into one conversation
    # — confirm against the agent implementation.
    session_id: Optional[str] = None
110
 
111
+
112
  class QueryResponse(BaseModel):
113
  query: str
114
  answer: str
 
120
  truncated: bool
121
  latency_ms: float
122
 
123
+
124
@app.get("/")
def serve_frontend():
    """Serve the bundled frontend entry page if present, else service metadata."""
    index_page = "frontend/index.html"
    if not os.path.exists(index_page):
        # No frontend bundle on disk — fall back to a small JSON banner.
        return {"name": "NyayaSetu", "version": "2.0.0", "agent": AGENT_VERSION}
    return FileResponse(index_page)
129
 
130
+
131
@app.get("/health")
def health():
    """Liveness probe: static service metadata plus the active agent version."""
    return {
        "status": "ok",
        "service": "NyayaSetu",
        "version": "2.0.0",
        "agent": AGENT_VERSION,
    }
134
 
135
+
136
  @app.post("/query", response_model=QueryResponse)
137
  def query(request: QueryRequest):
138
  if not request.query.strip():
src/citation_graph.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Precedent Chain Builder — Runtime Module.
3
+
4
+ Loads citation graph built offline by preprocessing/build_citation_graph.py.
5
+ At query time, enriches retrieved chunks with cited predecessor judgments.
6
+
7
+ WHY:
8
+ Indian SC judgments build on each other. A 1984 judgment establishing
9
+ a key principle was itself built on a 1971 judgment. Showing the user
10
+ the reasoning chain across cases makes NyayaSetu feel like a legal
11
+ researcher, not a search engine.
12
+
13
+ The graph is loaded once at startup and kept in memory.
14
+ Lookup is O(1) dict access — negligible runtime cost.
15
+ """
16
+
17
+ import os
18
+ import json
19
+ import re
20
+ import logging
21
+ from typing import List, Dict, Optional
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
# ── Graph store ───────────────────────────────────────────
# All stores are populated once by load_citation_graph() at startup and then
# only read at query time; lookups are O(1) dict access.
_graph: Dict[str, List[str]] = {}          # judgment_id -> [citation_strings]
_reverse_graph: Dict[str, List[str]] = {}  # citation_string -> [judgment_ids]
_title_to_id: Dict[str, str] = {}          # normalised_title -> judgment_id
_parent_store: Dict[str, str] = {}         # judgment_id -> text (loaded from parent_judgments.jsonl)
_loaded: bool = False                      # flipped by load_citation_graph()
31
+
32
+
33
+ def load_citation_graph(
34
+ graph_path: str = "data/citation_graph.json",
35
+ reverse_path: str = "data/reverse_citation_graph.json",
36
+ title_path: str = "data/title_to_id.json",
37
+ parent_path: str = "data/parent_judgments.jsonl"
38
+ ):
39
+ """
40
+ Load all citation graph artifacts once at startup.
41
+ Call from api/main.py after download_models().
42
+ Fails gracefully if files not found.
43
+ """
44
+ global _graph, _reverse_graph, _title_to_id, _parent_store, _loaded
45
+
46
+ try:
47
+ if os.path.exists(graph_path):
48
+ with open(graph_path) as f:
49
+ _graph = json.load(f)
50
+ logger.info(f"Citation graph loaded: {len(_graph)} judgments")
51
+ else:
52
+ logger.warning(f"Citation graph not found at {graph_path}")
53
+
54
+ if os.path.exists(reverse_path):
55
+ with open(reverse_path) as f:
56
+ _reverse_graph = json.load(f)
57
+ logger.info(f"Reverse citation graph loaded: {len(_reverse_graph)} citations")
58
+
59
+ if os.path.exists(title_path):
60
+ with open(title_path) as f:
61
+ _title_to_id = json.load(f)
62
+ logger.info(f"Title index loaded: {len(_title_to_id)} titles")
63
+
64
+ # Load parent judgments for text retrieval
65
+ if os.path.exists(parent_path):
66
+ with open(parent_path) as f:
67
+ for line in f:
68
+ line = line.strip()
69
+ if not line:
70
+ continue
71
+ try:
72
+ j = json.loads(line)
73
+ jid = j.get("judgment_id", "")
74
+ if jid:
75
+ _parent_store[jid] = j.get("text", "")
76
+ except Exception:
77
+ continue
78
+ logger.info(f"Parent store loaded: {len(_parent_store)} judgments")
79
+
80
+ _loaded = True
81
+
82
+ except Exception as e:
83
+ logger.error(f"Citation graph load failed: {e}. Precedent chain disabled.")
84
+ _loaded = False
85
+
86
+
87
def _resolve_citation_to_judgment(citation_string: str) -> Optional[str]:
    """
    Try to match a citation string to a judgment_id.

    Uses multiple strategies in order of reliability:
      1. exact hit in the reverse citation graph;
      2. exact hit on the normalised title;
      3. prefix match against the title index.

    Returns:
        The resolved judgment_id, or None if nothing matched.
    """
    if not citation_string:
        return None

    # Strategy 1: exact lookup in the reverse graph. Single .get() instead of
    # the `in` test followed by a second subscript lookup.
    refs = _reverse_graph.get(citation_string)
    if refs:
        return refs[0]

    # Strategy 2: normalise (lowercase, strip punctuation, cap at 50 chars)
    # and check the title index.
    normalised = re.sub(r'[^\w\s]', '', citation_string.lower())[:50]
    if normalised in _title_to_id:
        return _title_to_id[normalised]

    # Strategy 3: prefix match over the title index. The length guard and the
    # 20-char prefix are loop-invariant — hoist them out of the O(n) scan
    # instead of recomputing per title.
    if len(normalised) > 10:
        prefix = normalised[:20]
        for title, jid in _title_to_id.items():
            if prefix in title:
                return jid

    return None
112
+
113
+
114
def get_precedent_chain(
    judgment_ids: List[str],
    max_precedents: int = 3
) -> List[Dict]:
    """
    Given a list of retrieved judgment IDs, return their cited predecessors.

    Args:
        judgment_ids: IDs of judgments already retrieved by FAISS
        max_precedents: maximum number of precedent chunks to return

    Returns:
        List of precedent dicts with same structure as regular chunks,
        plus 'is_precedent': True and 'cited_by' field.
    """
    if not _loaded or not _graph:
        return []

    chain: List[Dict] = []
    # Never re-surface a judgment that was already retrieved (or already added).
    excluded = set(judgment_ids)

    for source_id in judgment_ids:
        # Consider at most the first 3 citations of each source judgment.
        for ref in _graph.get(source_id, [])[:3]:
            target_id = _resolve_citation_to_judgment(ref)
            if not target_id or target_id in excluded:
                continue

            full_text = _parent_store.get(target_id, "")
            if not full_text:
                # Without stored text there is nothing useful to show.
                continue

            excluded.add(target_id)

            chain.append({
                "judgment_id": target_id,
                "chunk_id": f"{target_id}_precedent",
                # Useful excerpt: first 1500 chars after any header.
                "text": full_text[:1500].strip(),
                "title": f"Precedent: {ref[:80]}",
                "year": target_id.split("_")[1] if "_" in target_id else "",
                "source_type": "case_law",
                "is_precedent": True,
                "cited_by": source_id,
                "citation_ref": ref,
                "similarity_score": 0.5  # precedents are added, not ranked
            })

            if len(chain) >= max_precedents:
                break

        if len(chain) >= max_precedents:
            break

    if chain:
        logger.info(f"Precedent chain: added {len(chain)} predecessor judgments")

    return chain
179
+
180
+
181
def get_citation_count(judgment_id: str) -> int:
    """How many times has this judgment been cited by others.

    Scans every citation string in the forward graph and counts the ones that
    resolve back to *judgment_id* — O(total citations) per call.
    """
    return sum(
        1
        for citation_list in _graph.values()
        for ref in citation_list
        if _resolve_citation_to_judgment(ref) == judgment_id
    )
190
+
191
+
192
def is_loaded() -> bool:
    """Whether load_citation_graph() has completed without a fatal error."""
    return _loaded