b24122 commited on
Commit
5d0a351
·
1 Parent(s): 9ae222c

Improve legal case evaluation with Gemini AI and enhanced RAG system

Browse files

Refactors GeminiService and RAGService to improve case evaluation using a dual retrieval RAG with LegalBERT predictions and Gemini AI evaluation.

Replit-Commit-Author: Agent
Replit-Commit-Session-Id: 63975d62-3d3b-48af-8685-b7e915f31f2b
Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/a5a12774-3181-414d-89e4-a4da8e3fb1ca/63975d62-3d3b-48af-8685-b7e915f31f2b/i8A93Md

app/api/routes.py CHANGED
@@ -47,28 +47,21 @@ async def analyze_case(request: CaseAnalysisRequest):
47
 
48
  logger.info(f"Initial verdict: {initial_verdict}, confidence: {confidence}")
49
 
50
- # Step 2: Retrieve supporting legal documents using RAG
51
- if request.useQueryGeneration:
52
- support_chunks, search_query = rag_service.retrieveDualSupportChunks(
53
- request.caseText, gemini_service
54
- )
55
- else:
56
- support_chunks, logs = rag_service.retrieveSupportChunksParallel(request.caseText)
57
- search_query = request.caseText
58
-
59
- logger.info(f"Retrieved support chunks from {len(support_chunks)} sources")
60
-
61
- # Step 3: Evaluate with Gemini AI
62
  evaluation_result = gemini_service.evaluateCaseWithGemini(
63
  inputText=request.caseText,
64
  modelVerdict=initial_verdict,
65
  confidence=confidence,
66
- support=support_chunks,
67
- searchQuery=search_query
68
  )
69
 
 
 
 
70
  logger.info(f"Gemini evaluation completed. Final verdict: {evaluation_result.get('finalVerdictByGemini')}")
71
 
 
72
  return CaseAnalysisResponse(
73
  initialVerdict=initial_verdict,
74
  initialConfidence=confidence,
 
47
 
48
  logger.info(f"Initial verdict: {initial_verdict}, confidence: {confidence}")
49
 
50
+ # Step 2: Evaluate with Gemini AI using RAG
 
 
 
 
 
 
 
 
 
 
 
51
  evaluation_result = gemini_service.evaluateCaseWithGemini(
52
  inputText=request.caseText,
53
  modelVerdict=initial_verdict,
54
  confidence=confidence,
55
+ retrieveFn=rag_service,
56
+ geminiQueryModel=gemini_service if request.useQueryGeneration else None
57
  )
58
 
59
+ logger.info(f"Retrieved support chunks from RAG system")
60
+ search_query = evaluation_result.get("ragSearchQuery", request.caseText)
61
+
62
  logger.info(f"Gemini evaluation completed. Final verdict: {evaluation_result.get('finalVerdictByGemini')}")
63
 
64
+ support_chunks = evaluation_result.get("support", {})
65
  return CaseAnalysisResponse(
66
  initialVerdict=initial_verdict,
67
  initialConfidence=confidence,
app/services/gemini_service.py CHANGED
@@ -22,7 +22,7 @@ class GeminiService:
22
  except Exception as e:
23
  logger.error(f"Failed to initialize Gemini client: {str(e)}")
24
 
25
- def generateSearchQueryFromCase(self, caseFacts: str, verbose: bool = False) -> str:
26
  if not self.client:
27
  raise ValueError("Gemini client not initialized")
28
 
@@ -58,7 +58,7 @@ Return only the search query, no explanation or prefix:
58
  if response.text:
59
  query = response.text.replace("Search Query:", "").strip().strip('"').replace("\n", "")
60
  else:
61
- query = caseFacts[:50] # Fallback to first 50 chars
62
 
63
  if verbose:
64
  logger.info(f"Generated RAG Query: {query}")
@@ -68,18 +68,18 @@ Return only the search query, no explanation or prefix:
68
  logger.error(f"Error generating search query: {str(e)}")
69
  raise ValueError(f"Search query generation failed: {str(e)}")
70
 
71
- def _build_gemini_prompt(self, input_text: str, model_verdict: str, confidence: float,
72
- support: Dict[str, List], query: Optional[str] = None) -> str:
73
- verdict_outcome = "a loss for the person" if model_verdict.lower() == "guilty" else "in favor of the person"
74
 
75
  prompt = f"""You are a judge evaluating a legal dispute under Indian law.
76
 
77
  ### Case Facts:
78
- {input_text}
79
 
80
  ### Initial Model Verdict:
81
- {model_verdict.upper()} (Confidence: {confidence * 100:.2f}%)
82
- This verdict is interpreted as {verdict_outcome}.
83
  """
84
 
85
  if query:
@@ -122,8 +122,8 @@ This verdict is interpreted as {verdict_outcome}.
122
  2. If relevant past cases appear in the retrieved materials, summarize them and analyze whether they support or contradict the model's verdict.
123
 
124
  3. Using the above, assess the model's prediction:
125
- - If confidence is below {settings.confidence_threshold * 100}%, you may revise or retain it.
126
- - If confidence is {settings.confidence_threshold * 100}% or higher, retain unless clear legal grounds exist to challenge it.
127
 
128
  4. Provide a thorough and formal legal explanation that:
129
  - Justifies the final decision using legal logic
@@ -139,31 +139,33 @@ Respond in the tone of a formal Indian judge. Your explanation should reflect re
139
  """
140
  return prompt
141
 
142
- def _extract_final_verdict(self, gemini_output: str) -> tuple[Optional[str], str]:
143
- verdict_match = re.search(r"final verdict\s*[:\-]\s*(guilty|not guilty)", gemini_output, re.IGNORECASE)
144
- changed_match = re.search(r"verdict changed\s*[:\-]\s*(yes|no)", gemini_output, re.IGNORECASE)
145
 
146
- final_verdict = verdict_match.group(1).lower() if verdict_match else None
147
- verdict_changed = "changed" if changed_match and changed_match.group(1).lower() == "yes" else "not changed"
148
 
149
- return final_verdict, verdict_changed
150
 
151
- def evaluateCaseWithGemini(self, inputText: str, modelVerdict: str, confidence: float,
152
- support: Dict[str, List], searchQuery: str) -> Dict[str, Any]:
153
- if not self.client:
154
- raise ValueError("Gemini client not initialized")
155
-
156
  try:
157
- prompt = self._build_gemini_prompt(inputText, modelVerdict, confidence, support, searchQuery)
158
-
 
 
 
 
 
159
  response = self.client.models.generate_content(
160
  model=settings.gemini_model,
161
  contents=prompt
162
  )
163
-
164
  geminiOutput = response.text if response.text else "No response from Gemini"
165
- finalVerdict, verdictChanged = self._extract_final_verdict(geminiOutput)
166
-
 
167
  logs = {
168
  "inputText": inputText,
169
  "modelVerdict": modelVerdict,
@@ -175,16 +177,16 @@ Respond in the tone of a formal Indian judge. Your explanation should reflect re
175
  "verdictChanged": verdictChanged,
176
  "ragSearchQuery": searchQuery
177
  }
178
-
179
  return logs
 
180
  except Exception as e:
181
- logger.error(f"Error in Gemini evaluation: {str(e)}")
182
  return {
183
  "error": str(e),
184
  "inputText": inputText,
185
  "modelVerdict": modelVerdict,
186
  "confidence": confidence,
187
- "ragSearchQuery": searchQuery,
188
  "support": None,
189
  "promptToGemini": None,
190
  "geminiOutput": None,
 
22
  except Exception as e:
23
  logger.error(f"Failed to initialize Gemini client: {str(e)}")
24
 
25
+ def generateSearchQueryFromCase(self, caseFacts: str, geminiModel=None, verbose: bool = False) -> str:
26
  if not self.client:
27
  raise ValueError("Gemini client not initialized")
28
 
 
58
  if response.text:
59
  query = response.text.replace("Search Query:", "").strip().strip('"').replace("\n", "")
60
  else:
61
+ query = caseFacts[:50] # Fallback
62
 
63
  if verbose:
64
  logger.info(f"Generated RAG Query: {query}")
 
68
  logger.error(f"Error generating search query: {str(e)}")
69
  raise ValueError(f"Search query generation failed: {str(e)}")
70
 
71
+ def buildGeminiPrompt(self, inputText: str, modelVerdict: str, confidence: float,
72
+ support: Dict[str, List], query: Optional[str] = None) -> str:
73
+ verdictOutcome = "a loss for the person" if modelVerdict.lower() == "guilty" else "in favor of the person"
74
 
75
  prompt = f"""You are a judge evaluating a legal dispute under Indian law.
76
 
77
  ### Case Facts:
78
+ {inputText}
79
 
80
  ### Initial Model Verdict:
81
+ {modelVerdict.upper()} (Confidence: {confidence * 100:.2f}%)
82
+ This verdict is interpreted as {verdictOutcome}.
83
  """
84
 
85
  if query:
 
122
  2. If relevant past cases appear in the retrieved materials, summarize them and analyze whether they support or contradict the model's verdict.
123
 
124
  3. Using the above, assess the model's prediction:
125
+ - If confidence is below 60%, you may revise or retain it.
126
+ - If confidence is 60% or higher, retain unless clear legal grounds exist to challenge it.
127
 
128
  4. Provide a thorough and formal legal explanation that:
129
  - Justifies the final decision using legal logic
 
139
  """
140
  return prompt
141
 
142
+ def extractFinalVerdict(self, geminiOutput: str) -> tuple[Optional[str], str]:
143
+ verdictMatch = re.search(r"final verdict\s*[:\-]\s*(guilty|not guilty)", geminiOutput, re.IGNORECASE)
144
+ changedMatch = re.search(r"verdict changed\s*[:\-]\s*(yes|no)", geminiOutput, re.IGNORECASE)
145
 
146
+ finalVerdict = verdictMatch.group(1).lower() if verdictMatch else None
147
+ verdictChanged = "changed" if changedMatch and changedMatch.group(1).lower() == "yes" else "not changed"
148
 
149
+ return finalVerdict, verdictChanged
150
 
151
+ def evaluateCaseWithGemini(self, inputText: str, modelVerdict: str, confidence: float,
152
+ retrieveFn, geminiQueryModel=None):
 
 
 
153
  try:
154
+ if geminiQueryModel:
155
+ support, searchQuery = retrieveFn.retrieveDualSupportChunks(inputText, self)
156
+ else:
157
+ support, _ = retrieveFn.retrieveSupportChunksParallel(inputText)
158
+ searchQuery = inputText
159
+
160
+ prompt = self.buildGeminiPrompt(inputText, modelVerdict, confidence, support, searchQuery)
161
  response = self.client.models.generate_content(
162
  model=settings.gemini_model,
163
  contents=prompt
164
  )
 
165
  geminiOutput = response.text if response.text else "No response from Gemini"
166
+
167
+ finalVerdict, verdictChanged = self.extractFinalVerdict(geminiOutput)
168
+
169
  logs = {
170
  "inputText": inputText,
171
  "modelVerdict": modelVerdict,
 
177
  "verdictChanged": verdictChanged,
178
  "ragSearchQuery": searchQuery
179
  }
180
+
181
  return logs
182
+
183
  except Exception as e:
 
184
  return {
185
  "error": str(e),
186
  "inputText": inputText,
187
  "modelVerdict": modelVerdict,
188
  "confidence": confidence,
189
+ "ragSearchQuery": None,
190
  "support": None,
191
  "promptToGemini": None,
192
  "geminiOutput": None,
app/services/rag_service.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import os
 
3
  from concurrent.futures import ThreadPoolExecutor
4
  from typing import Dict, List, Any, Tuple
5
  from app.core.config import settings
@@ -16,87 +17,143 @@ class RAGService:
16
 
17
  def _initialize_encoder(self):
18
  try:
19
- logger.info(f"Sentence transformer placeholder initialized")
20
- # TODO: Initialize actual sentence transformer when dependencies are available
 
 
 
 
21
  self.encoder = "placeholder"
22
  except Exception as e:
23
- logger.error(f"Failed to initialize encoder: {str(e)}")
 
24
 
25
- def _load_faiss_index_and_chunks(self, indexPath: str, chunkPath: str) -> Tuple[Any, List]:
26
  try:
27
  if not os.path.exists(indexPath) or not os.path.exists(chunkPath):
28
  logger.warning(f"Missing files: {indexPath} or {chunkPath}")
29
  return None, []
30
 
31
- # TODO: Load actual FAISS index when faiss-cpu is available
 
 
 
 
 
32
 
33
  if chunkPath.endswith('.pkl'):
34
- logger.info(f"Placeholder for pickle file: {chunkPath}")
35
- chunks = []
36
  else:
37
- try:
38
- with open(chunkPath, 'r', encoding='utf-8') as f:
39
- chunks = json.load(f)
40
- except:
41
- chunks = []
42
 
43
- logger.info(f"Loaded index placeholder from {indexPath} with {len(chunks)} chunks")
44
- return "placeholder_index", chunks
45
  except Exception as e:
46
  logger.error(f"Failed to load index {indexPath}: {str(e)}")
47
  return None, []
48
 
49
  def _load_indexes(self):
50
- indexConfigs = {
51
- "constitution": (settings.constitution_index_path, settings.constitution_chunks_path),
52
- "ipcSections": (settings.ipc_index_path, settings.ipc_chunks_path),
53
- "ipcCase": (settings.ipc_case_index_path, settings.ipc_case_chunks_path),
54
- "statutes": (settings.statute_index_path, settings.statute_chunks_path),
55
- "qaTexts": (settings.qa_index_path, settings.qa_chunks_path),
56
- "caseLaw": (settings.case_law_index_path, settings.case_law_chunks_path)
 
57
  }
58
 
59
- for name, (indexPath, chunkPath) in indexConfigs.items():
60
- indexData = self._load_faiss_index_and_chunks(indexPath, chunkPath)
61
- if indexData[0] is not None:
62
- self.preloadedIndexes[name] = indexData
63
- logger.info(f"Successfully loaded {name} index placeholder")
64
- else:
65
- logger.warning(f"Failed to load {name} index")
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def retrieveSupportChunksParallel(self, inputText: str) -> Tuple[Dict[str, List], Dict]:
68
- logger.info("Using placeholder RAG retrieval")
69
-
70
- logs = {"query": inputText}
 
 
 
 
 
 
 
 
 
71
 
72
- # Return placeholder support chunks
73
- support = {}
74
- for name in ["constitution", "ipcSections", "ipcCase", "statutes", "qaTexts", "caseLaw"]:
75
- if name in self.preloadedIndexes:
76
- _, chunks = self.preloadedIndexes[name]
77
- support[name] = chunks[:5] if chunks else []
78
- else:
79
- support[name] = []
80
-
81
- logs["supportChunksUsed"] = str(support)
82
- return support, logs
83
-
84
- def retrieveDualSupportChunks(self, inputText: str, geminiService) -> Tuple[Dict[str, List], str]:
85
  try:
86
- # Generate search query using Gemini
87
- geminiQuery = None
88
- try:
89
- geminiQuery = geminiService.generateSearchQueryFromCase(inputText)
90
- except Exception as e:
91
- logger.warning(f"Failed to generate Gemini query: {str(e)}")
92
 
93
- # Use placeholder retrieval
94
- support, _ = self.retrieveSupportChunksParallel(inputText)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- return support, geminiQuery or inputText
97
  except Exception as e:
98
- logger.error(f"Error in dual support retrieval: {str(e)}")
99
- raise ValueError(f"Dual support retrieval failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  def areIndexesLoaded(self) -> bool:
102
  return len(self.preloadedIndexes) > 0
 
1
  import json
2
  import os
3
+ import pickle
4
  from concurrent.futures import ThreadPoolExecutor
5
  from typing import Dict, List, Any, Tuple
6
  from app.core.config import settings
 
17
 
18
  def _initialize_encoder(self):
19
  try:
20
+ from sentence_transformers import SentenceTransformer
21
+ logger.info(f"Loading sentence transformer: {settings.sentence_transformer_model}")
22
+ self.encoder = SentenceTransformer(settings.sentence_transformer_model)
23
+ logger.info("Sentence transformer loaded successfully")
24
+ except ImportError:
25
+ logger.warning("sentence-transformers not installed - using placeholder mode")
26
  self.encoder = "placeholder"
27
  except Exception as e:
28
+ logger.error(f"Failed to load sentence transformer: {str(e)}")
29
+ self.encoder = "placeholder"
30
 
31
+ def loadFaissIndexAndChunks(self, indexPath: str, chunkPath: str) -> Tuple[Any, List]:
32
  try:
33
  if not os.path.exists(indexPath) or not os.path.exists(chunkPath):
34
  logger.warning(f"Missing files: {indexPath} or {chunkPath}")
35
  return None, []
36
 
37
+ try:
38
+ import faiss
39
+ index = faiss.read_index(indexPath)
40
+ except ImportError:
41
+ logger.warning("faiss-cpu not installed - returning placeholder")
42
+ return "placeholder_index", []
43
 
44
  if chunkPath.endswith('.pkl'):
45
+ with open(chunkPath, 'rb') as f:
46
+ chunks = pickle.load(f)
47
  else:
48
+ with open(chunkPath, 'r', encoding='utf-8') as f:
49
+ chunks = json.load(f)
 
 
 
50
 
51
+ logger.info(f"Loaded index from {indexPath} with {len(chunks)} chunks")
52
+ return index, chunks
53
  except Exception as e:
54
  logger.error(f"Failed to load index {indexPath}: {str(e)}")
55
  return None, []
56
 
57
  def _load_indexes(self):
58
+ basePath = settings.faiss_indexes_base_path
59
+ self.preloadedIndexes = {
60
+ "constitution": self.loadFaissIndexAndChunks(f"{basePath}/constitution_bgeLarge.index", f"{basePath}/constitution_chunks.json"),
61
+ "ipcSections": self.loadFaissIndexAndChunks(f"{basePath}/ipc_bgeLarge.index", f"{basePath}/ipc_chunks.json"),
62
+ "ipcCase": self.loadFaissIndexAndChunks(f"{basePath}/ipc_case_flat.index", f"{basePath}/ipc_case_chunks.json"),
63
+ "statutes": self.loadFaissIndexAndChunks(f"{basePath}/statute_index.faiss", f"{basePath}/statute_chunks.pkl"),
64
+ "qaTexts": self.loadFaissIndexAndChunks(f"{basePath}/qa_faiss_index.idx", f"{basePath}/qa_text_chunks.json"),
65
+ "caseLaw": self.loadFaissIndexAndChunks(f"{basePath}/case_faiss.index", f"{basePath}/case_chunks.pkl")
66
  }
67
 
68
+ # Remove failed loads
69
+ self.preloadedIndexes = {k: v for k, v in self.preloadedIndexes.items() if v[0] is not None}
70
+ logger.info(f"Successfully loaded {len(self.preloadedIndexes)} indexes")
71
+
72
+ def search(self, index: Any, chunks: List, queryEmbedding, topK: int) -> List[Tuple[float, Any]]:
73
+ try:
74
+ if index == "placeholder_index":
75
+ return [(0.5, chunk) for chunk in chunks[:topK]]
76
+
77
+ import faiss
78
+ D, I = index.search(queryEmbedding, topK)
79
+ results = []
80
+ for score, idx in zip(D[0], I[0]):
81
+ if idx < len(chunks):
82
+ results.append((score, chunks[idx]))
83
+ return results
84
+ except Exception as e:
85
+ logger.error(f"Search failed: {str(e)}")
86
+ return []
87
 
88
  def retrieveSupportChunksParallel(self, inputText: str) -> Tuple[Dict[str, List], Dict]:
89
+ if self.encoder == "placeholder":
90
+ logger.info("Using placeholder RAG retrieval")
91
+ logs = {"query": inputText}
92
+ support = {}
93
+ for name in ["constitution", "ipcSections", "ipcCase", "statutes", "qaTexts", "caseLaw"]:
94
+ if name in self.preloadedIndexes:
95
+ _, chunks = self.preloadedIndexes[name]
96
+ support[name] = chunks[:5] if chunks else []
97
+ else:
98
+ support[name] = []
99
+ logs["supportChunksUsed"] = support
100
+ return support, logs
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  try:
103
+ import faiss
104
+ queryEmbedding = self.encoder.encode([inputText], normalize_embeddings=True).astype('float32')
105
+ faiss.normalize_L2(queryEmbedding)
106
+
107
+ logs = {"query": inputText}
 
108
 
109
+ def retrieve(name):
110
+ if name not in self.preloadedIndexes:
111
+ return name, []
112
+ idx, chunks = self.preloadedIndexes[name]
113
+ results = self.search(idx, chunks, queryEmbedding, 5)
114
+ return name, [c[1] for c in results]
115
+
116
+ support = {}
117
+ with ThreadPoolExecutor(max_workers=6) as executor:
118
+ futures = [executor.submit(retrieve, name) for name in self.preloadedIndexes.keys()]
119
+ for f in futures:
120
+ name, topChunks = f.result()
121
+ support[name] = topChunks
122
+
123
+ logs["supportChunksUsed"] = support
124
+ return support, logs
125
 
 
126
  except Exception as e:
127
+ logger.error(f"Error retrieving support chunks: {str(e)}")
128
+ raise ValueError(f"Support chunk retrieval failed: {str(e)}")
129
+
130
+ def retrieveDualSupportChunks(self, inputText: str, geminiQueryModel):
131
+ try:
132
+ geminiQuery = geminiQueryModel.generateSearchQueryFromCase(inputText, geminiQueryModel)
133
+ except:
134
+ geminiQuery = None
135
+
136
+ supportFromCase, _ = self.retrieveSupportChunksParallel(inputText)
137
+ supportFromQuery, _ = self.retrieveSupportChunksParallel(geminiQuery or inputText)
138
+
139
+ combinedSupport = {}
140
+ for key in supportFromCase:
141
+ combined = supportFromCase[key] + supportFromQuery[key]
142
+ seen = set()
143
+ unique = []
144
+ for chunk in combined:
145
+ if isinstance(chunk, str):
146
+ rep = chunk
147
+ else:
148
+ rep = chunk.get("text") or chunk.get("description") or chunk.get("section_desc") or str(chunk)
149
+ if rep not in seen:
150
+ seen.add(rep)
151
+ unique.append(chunk)
152
+ if len(unique) == 10:
153
+ break
154
+ combinedSupport[key] = unique
155
+
156
+ return combinedSupport, geminiQuery
157
 
158
  def areIndexesLoaded(self) -> bool:
159
  return len(self.preloadedIndexes) > 0
attached_assets/raggy (3)_1753479511222.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """raggy.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1qpREkLNBZPP521tI9IvkNaB3FaLnlH9d
8
+ """
9
+
10
+ from google.colab import drive
11
+ drive.mount('/content/drive')
12
+
13
+
14
+ !pip install faiss-cpu --quiet
15
+
16
+
17
+ !pip install faiss-cpu -q
18
+
19
+
20
+ import zipfile
21
+ import os
22
+
23
+ zipPath = "/content/drive/MyDrive/legalbert_epoch4.zip"
24
+ extractPath = "/content/legalbert_model"
25
+
26
+ with zipfile.ZipFile(zipPath, 'r') as zipRef:
27
+ zipRef.extractall(extractPath)
28
+
29
+ print("Model unzipped at:", extractPath)
30
+
31
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
32
+ import torch
33
+
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+
36
+ tokenizer = AutoTokenizer.from_pretrained("/content/legalbert_model")
37
+ legalBertModel = AutoModelForSequenceClassification.from_pretrained("/content/legalbert_model").to(device)
38
+
39
+ print("Model and tokenizer loaded on", device)
40
+
41
+ import torch.nn.functional as F
42
+
43
def predictVerdict(inputText):
    """Classify *inputText* as "guilty" / "not guilty" with the fine-tuned LegalBERT model.

    Uses the module-level `tokenizer`, `legalBertModel`, and `device` globals.
    """
    # Truncate so long case texts fit the model's maximum sequence length.
    encoded = tokenizer(inputText, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        logits = legalBertModel(**encoded).logits
    # argmax over the softmax picks the most probable class; label 1 is "guilty".
    label = torch.argmax(F.softmax(logits, dim=1), dim=1).item()
    return "guilty" if label == 1 else "not guilty"
50
+
51
def getConfidence(inputText):
    """Return the model's top-class probability for *inputText* as a Python float.

    Mirrors predictVerdict's forward pass; reads the same module-level globals.
    """
    encoded = tokenizer(inputText, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        probs = F.softmax(legalBertModel(**encoded).logits, dim=1)
    return float(probs.max().item())
57
+
58
+ inputText = "The accused was found in possession of stolen property and failed to provide a valid explanation."
59
+
60
+ verdict = predictVerdict(inputText)
61
+ confidence = getConfidence(inputText)
62
+
63
+ print("Verdict:", verdict)
64
+ print("Confidence:", confidence)
65
+
66
+ !pip install -q google-generativeai
67
+
68
import google.generativeai as genai
import os

# SECURITY: the original committed a live Gemini API key in source. Never
# hard-code secrets — read the key from the environment instead and rotate
# the leaked key immediately.
apiKey = os.environ.get("GEMINI_API_KEY", "")
genai.configure(api_key=apiKey)

model = genai.GenerativeModel("gemini-2.5-flash")
75
+
76
def retrieveDualSupportChunks(inputText, geminiQueryModel):
    """Retrieve support chunks twice — once with the raw case facts and once with
    a Gemini-generated keyword query — then merge and de-duplicate per source.

    Args:
        inputText: full case-facts text.
        geminiQueryModel: Gemini model used to generate the keyword query.

    Returns:
        (combinedSupport, geminiQuery) — geminiQuery is None when query
        generation failed, in which case both retrievals used the raw text.
    """
    try:
        geminiQuery = generateSearchQueryFromCase(inputText, geminiQueryModel)
    except Exception as e:
        # Best-effort fallback to the raw case text, but do not hide the failure
        # (the original bare `except:` silently swallowed everything).
        print(f"Gemini query generation failed, falling back to case text: {e}")
        geminiQuery = None

    supportFromCase, _ = retrieveSupportChunksParallel(inputText)
    supportFromQuery, _ = retrieveSupportChunksParallel(geminiQuery or inputText)

    combinedSupport = {}
    for key in supportFromCase:
        # .get guards against a source missing from the second retrieval.
        combined = supportFromCase[key] + supportFromQuery.get(key, [])
        seen = set()
        unique = []
        for chunk in combined:
            # Chunks may be plain strings or dicts; derive a stable text key for dedup.
            if isinstance(chunk, str):
                rep = chunk
            else:
                rep = chunk.get("text") or chunk.get("description") or chunk.get("section_desc") or str(chunk)
            if rep not in seen:
                seen.add(rep)
                unique.append(chunk)
            if len(unique) == 10:  # cap merged evidence at 10 chunks per source
                break
        combinedSupport[key] = unique

    return combinedSupport, geminiQuery
103
+
104
+ import json
105
+
106
+ path = "/content/drive/MyDrive/faiss_indexes/constitution_bge_chunks.json"
107
+
108
+ with open(path, "r", encoding="utf-8") as f:
109
+ data = json.load(f)
110
+
111
+
112
+ for i, item in enumerate(data[:5]):
113
+ print(f"🔹 Chunk {i+1}:\n{item}\n")
114
+
115
+ import json
116
+
117
+ path="/content/drive/MyDrive/faiss_indexes/constitution_chunks.json"
118
+ with open(path, "r", encoding="utf-8") as f:
119
+ data = json.load(f)
120
+
121
+
122
+ for i, item in enumerate(data[:5]):
123
+ print(f"🔹 Chunk {i+1}:\n{item}\n")
124
+
125
+
126
+ import faiss
127
+ import numpy as np
128
+ import json
129
+ import pickle
130
+ from sentence_transformers import SentenceTransformer
131
+
132
+ encoder = SentenceTransformer('BAAI/bge-large-en-v1.5')
133
+ basePath = "/content/drive/MyDrive/faiss_indexes"
134
+
135
def loadFaissIndexAndChunks(indexPath, chunkPath):
    """Load a FAISS index plus its chunk store and return (index, chunks).

    The chunk store is pickle (`.pkl`) or JSON depending on the file extension.
    """
    index = faiss.read_index(indexPath)
    if chunkPath.endswith('.pkl'):
        # NOTE(review): pickle is only safe for trusted, locally produced files.
        with open(chunkPath, 'rb') as f:
            chunks = pickle.load(f)
    else:
        # Explicit encoding so JSON chunk files load identically across platforms
        # (the original relied on the locale default encoding).
        with open(chunkPath, 'r', encoding='utf-8') as f:
            chunks = json.load(f)
    return index, chunks
140
+
141
def search(index, chunks, queryEmbedding, topK):
    """Run a FAISS top-K search and pair each score with its chunk.

    FAISS pads missing results with id -1 when fewer than topK neighbours
    exist; the original `idx < len(chunks)` check let -1 through and silently
    returned chunks[-1], so we additionally require idx >= 0.
    """
    D, I = index.search(queryEmbedding, topK)
    results = []
    for score, idx in zip(D[0], I[0]):
        if 0 <= idx < len(chunks):  # skip FAISS -1 padding and out-of-range ids
            results.append((score, chunks[idx]))
    return results
148
+
149
+ from concurrent.futures import ThreadPoolExecutor
150
def retrieveSupportChunksParallel(inputText):
    """Embed *inputText* once, then query every preloaded FAISS index concurrently.

    Returns (support, logs): support maps index name -> top-5 chunks; logs
    records the query text and the chunks that were used.
    """
    embedding = encoder.encode([inputText], normalize_embeddings=True).astype('float32')
    faiss.normalize_L2(embedding)

    def _lookup(name):
        # Each worker searches a single index; `search` yields (score, chunk) pairs.
        index, chunks = preloadedIndexes[name]
        hits = search(index, chunks, embedding, 5)
        return name, [chunk for _, chunk in hits]

    with ThreadPoolExecutor(max_workers=6) as pool:
        support = dict(pool.map(_lookup, preloadedIndexes))

    return support, {"query": inputText, "supportChunksUsed": support}
170
+
171
+ preloadedIndexes = {
172
+ "constitution": loadFaissIndexAndChunks(f"{basePath}/constitution_bgeLarge.index", f"{basePath}/constitution_chunks.json"),
173
+ "ipcSections": loadFaissIndexAndChunks(f"{basePath}/ipc_bgeLarge.index", f"{basePath}/ipc_chunks.json"),
174
+ "ipcCase": loadFaissIndexAndChunks(f"{basePath}/ipc_case_flat.index", f"{basePath}/ipc_case_chunks.json"),
175
+ "statutes": loadFaissIndexAndChunks(f"{basePath}/statute_index.faiss", f"{basePath}/statute_chunks.pkl"),
176
+ "qaTexts": loadFaissIndexAndChunks(f"{basePath}/qa_faiss_index.idx", f"{basePath}/qa_text_chunks.json"),
177
+ "caseLaw": loadFaissIndexAndChunks(f"{basePath}/case_faiss.index", f"{basePath}/case_chunks.pkl")
178
+ }
179
+
180
def generateSearchQueryFromCase(caseFacts, geminiModel, verbose=False):
    """Ask Gemini to distill *caseFacts* into a short keyword query for RAG retrieval."""
    prompt = f"""
You are a legal assistant for a retrieval system based on Indian criminal law.

Given the case facts below, generate a **concise and focused search query** with **only the most relevant legal keywords**. These should include:

- Specific **IPC sections**
- Core **legal concepts** (e.g., "right of private defence", "criminal breach of trust")
- **Crime type** (e.g., "assault", "corruption")
- Any relevant **procedural issue** (e.g., "absence of intent", "lack of evidence")

Do **not** include:
- Full sentences
- Personal names
- Generic or vague words (e.g., "man", "incident", "case", "situation")

Keep the query under **20 words**. Separate terms by commas if needed. Optimize for legal document search.

Case Facts:
\"\"\"{caseFacts}\"\"\"

Return only the search query, no explanation or prefix:
"""
    raw = geminiModel.generate_content(prompt).text
    # Strip the optional "Search Query:" prefix, surrounding quotes, and newlines.
    query = raw.replace("Search Query:", "").strip().strip('"').replace("\n", "")

    if verbose:
        print("RAG Query:", query)

    return query
210
+
211
def buildGeminiPrompt(inputText, modelVerdict, confidence, support, query=None):
    """Assemble the judge-persona evaluation prompt from the case facts, the
    LegalBERT verdict/confidence, and the retrieved legal reference chunks.

    `support` maps source name -> list of chunks (strings or dicts); missing
    sources simply render an empty section.
    """
    verdictOutcome = "a loss for the person" if modelVerdict.lower() == "guilty" else "in favor of the person"

    parts = [f"""You are a judge evaluating a legal dispute under Indian law.

### Case Facts:
{inputText}

### Initial Model Verdict:
{modelVerdict.upper()} (Confidence: {confidence * 100:.2f}%)
This verdict is interpreted as {verdictOutcome}.
"""]

    if query:
        parts.append(f"\n### Legal Query Used:\n{query}\n")

    # (section header, support key) pairs — order mirrors the prompt layout.
    sections = [
        ("\n---\n\n### Legal References Retrieved:\n\n#### Constitution Articles (Top 5):\n", "constitution"),
        ("\n#### IPC Sections (Top 5):\n", "ipcSections"),
        ("\n#### IPC Case Law (Top 5):\n", "ipcCase"),
        ("\n#### Statutes (Top 5):\n", "statutes"),
        ("\n#### QA Texts (Top 5):\n", "qaTexts"),
        ("\n#### General Case Law (Top 5):\n", "caseLaw"),
    ]
    for header, key in sections:
        parts.append(header)
        for i, item in enumerate(support.get(key, [])):
            parts.append(f"- {i+1}. {str(item)}\n")

    parts.append(f"""

---

### Instructions to the Judge (You):

1. Review the legal materials provided:
- Identify which Constitution articles, IPC sections, statutes, and case laws are relevant to the facts.
- Also note and explain which retrieved references are **not applicable** or irrelevant.

2. If relevant past cases appear in the retrieved materials, summarize them and analyze whether they support or contradict the model’s verdict.

3. Using the above, assess the model's prediction:
- If confidence is below 60%, you may revise or retain it.
- If confidence is 60% or higher, retain unless clear legal grounds exist to challenge it.

4. Provide a thorough and formal legal explanation that:
- Justifies the final decision using legal logic
- Cites relevant IPCs, constitutional provisions, statutes, and precedents
- Explains any reasoning for overriding the model's prediction, if applicable

5. Conclude with the following lines, formatted as shown:

Final Verdict: Guilty or Not Guilty
Verdict Changed: Yes or No

Respond in the tone of a formal Indian judge. Your explanation should reflect reasoning, neutrality, and respect for legal procedure.
""")
    return "".join(parts)
280
+
281
+ import re
282
+
283
def extractFinalVerdict(geminiOutput):
    """Parse Gemini's free-text ruling for the "Final Verdict:" and "Verdict Changed:" lines.

    Returns (finalVerdict, verdictChanged): finalVerdict is "guilty",
    "not guilty", or None when absent; verdictChanged is "changed" or
    "not changed" (defaults to "not changed" when the line is missing).
    """
    verdict = re.search(r"final verdict\s*[:\-]\s*(guilty|not guilty)", geminiOutput, re.IGNORECASE)
    changed = re.search(r"verdict changed\s*[:\-]\s*(yes|no)", geminiOutput, re.IGNORECASE)

    finalVerdict = None
    if verdict:
        finalVerdict = verdict.group(1).lower()

    verdictChanged = "not changed"
    if changed and changed.group(1).lower() == "yes":
        verdictChanged = "changed"

    return finalVerdict, verdictChanged
291
+
292
def evaluateCaseWithGemini(inputText, modelVerdict, confidence, retrieveFn, geminiQueryModel=None):
    """Run the full RAG + Gemini evaluation pipeline for one case.

    Args:
        inputText: case facts.
        modelVerdict / confidence: LegalBERT prediction and its probability.
        retrieveFn: retrieval callable used when no query model is supplied.
        geminiQueryModel: optional Gemini model; when set, dual retrieval
            (raw facts + generated keyword query) is used instead of retrieveFn.

    Returns a log dict; on failure the dict carries an "error" key and the
    pipeline fields are None so downstream printing still works.
    """
    try:
        if geminiQueryModel:
            support, searchQuery = retrieveDualSupportChunks(inputText, geminiQueryModel)
        else:
            support, _ = retrieveFn(inputText)
            searchQuery = inputText

        prompt = buildGeminiPrompt(inputText, modelVerdict, confidence, support, searchQuery)
        geminiOutput = model.generate_content(prompt).text
        finalVerdict, verdictChanged = extractFinalVerdict(geminiOutput)

        return {
            "inputText": inputText,
            "modelVerdict": modelVerdict,
            "confidence": confidence,
            "support": support,
            "promptToGemini": prompt,
            "geminiOutput": geminiOutput,
            "finalVerdictByGemini": finalVerdict,
            "verdictChanged": verdictChanged,
            "ragSearchQuery": searchQuery,
        }
    except Exception as e:
        # Fail soft: return a log-shaped dict instead of raising.
        return {
            "error": str(e),
            "inputText": inputText,
            "modelVerdict": modelVerdict,
            "confidence": confidence,
            "ragSearchQuery": None,
            "support": None,
            "promptToGemini": None,
            "geminiOutput": None,
            "finalVerdictByGemini": None,
            "verdictChanged": None,
        }
333
+
334
+ import pandas as pd
335
+
336
+ df=pd.read_csv('/content/drive/MyDrive/Extracted/LegalRAGSystem/ILDC/test.csv')
337
+
338
+ df['Label'][1971]
339
+
340
+ inputText = df['Input'][1971]
341
+
342
+ verdict = predictVerdict(inputText)
343
+ confidence = getConfidence(inputText)
344
+
345
+ logs = evaluateCaseWithGemini(
346
+ inputText=inputText,
347
+ modelVerdict=verdict,
348
+ confidence=confidence,
349
+ retrieveFn=retrieveSupportChunksParallel,
350
+ geminiQueryModel=model
351
+ )
352
+
353
+ print("🔍 Query sent to RAG:", logs["ragSearchQuery"])
354
+ print(logs['modelVerdict'])
355
+ print(logs['confidence'])
356
+ # print("\n📜 Prompt to Gemini:\n", logs["promptToGemini"])
357
+ print("\n🧑‍⚖️ Gemini Verdict Output:\n", logs["geminiOutput"])
358
+ print("\n✅ Final Verdict:", logs["finalVerdictByGemini"])
359
+ print("🔁 Verdict Changed:", logs["verdictChanged"])
360
+
361
+ # import random
362
+
363
+ # sampleIndices = random.sample(range(len(df)), 5)
364
+ # correctCount = 0
365
+ # total = 0
366
+
367
+ # for idx in sampleIndices:
368
+ # inputText = df['Input'][idx]
369
+ # trueLabel = int(df['Label'][idx])
370
+
371
+ # verdict = predictVerdict(inputText)
372
+ # confidence = getConfidence(inputText)
373
+
374
+ # result = evaluateCaseWithGemini(
375
+ # inputText=inputText,
376
+ # modelVerdict=verdict,
377
+ # confidence=confidence,
378
+ # retrieveFn=retrieveSupportChunksParallel,
379
+ # geminiQueryModel=model
380
+ # )
381
+
382
+ # predicted = result.get("finalVerdictByGemini")
383
+ # if predicted is None:
384
+ # continue
385
+
386
+ # predictedLabel = 1 if predicted.lower() == "guilty" else 0
387
+
388
+ # print("Index:", idx)
389
+ # print("True Label:", trueLabel)
390
+ # print("Predicted Verdict:", predicted)
391
+ # print("Verdict Changed:", result.get("verdictChanged"))
392
+ # print("Match:", predictedLabel == trueLabel)
393
+ # print("----")
394
+
395
+ # correctCount += int(predictedLabel == trueLabel)
396
+ # total += 1
397
+
398
+ # print("Samples Evaluated:", total)
399
+ # print("Gemini Final Verdict Accuracy:", correctCount / total if total else 0)
400
+