b24122 committed on
Commit
00d8d42
·
0 Parent(s):

Initial commit

Browse files
attached_assets/raggy (3)_1753453411048.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
"""raggy.ipynb

Automatically generated by Colab.

Original file is located at
https://colab.research.google.com/drive/1qpREkLNBZPP521tI9IvkNaB3FaLnlH9d
"""

# Mount Google Drive so the fine-tuned model zip and FAISS index files are reachable.
from google.colab import drive
drive.mount('/content/drive')

# Colab cell magic: install the CPU build of FAISS.
# NOTE(review): the second install is a duplicate of the first (a no-op once installed).
!pip install faiss-cpu --quiet

!pip install faiss-cpu -q
18
+
19
+
20
import zipfile
import os

# Source: zipped LegalBERT checkpoint on Drive; destination: local extraction dir.
zipPath = "/content/drive/MyDrive/legalbert_epoch4.zip"
extractPath = "/content/legalbert_model"

# Extract the whole archive into the local model directory.
with zipfile.ZipFile(zipPath, 'r') as zipRef:
    zipRef.extractall(extractPath)

print("Model unzipped at:", extractPath)
30
+
31
# Load the fine-tuned LegalBERT sequence classifier (binary: guilty / not guilty).
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Prefer GPU when the Colab runtime provides one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("/content/legalbert_model")
legalBertModel = AutoModelForSequenceClassification.from_pretrained("/content/legalbert_model").to(device)

print("Model and tokenizer loaded on", device)
40
+
41
import torch.nn.functional as F

def predictVerdict(inputText):
    """Classify case facts with LegalBERT; label 1 maps to "guilty", otherwise "not guilty"."""
    encoded = tokenizer(inputText, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputLogits = legalBertModel(**encoded).logits
    classProbs = F.softmax(outputLogits, dim=1)
    labelIdx = torch.argmax(classProbs, dim=1).item()
    return "guilty" if labelIdx == 1 else "not guilty"
50
+
51
def getConfidence(inputText):
    """Return the winning-class probability (model confidence, in [0, 1]) for the input."""
    encoded = tokenizer(inputText, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputLogits = legalBertModel(**encoded).logits
    classProbs = F.softmax(outputLogits, dim=1)
    return float(classProbs.max().item())
57
+
58
# Quick smoke test of the classifier on a hand-written fact pattern.
inputText = "The accused was found in possession of stolen property and failed to provide a valid explanation."

verdict = predictVerdict(inputText)
confidence = getConfidence(inputText)

print("Verdict:", verdict)
print("Confidence:", confidence)
65
+
66
+ !pip install -q google-generativeai
67
+
68
+ import google.generativeai as genai
69
+ import os
70
+
71
+ apiKey = "AIzaSyB2MlvYuABxIQjs42lZsASp78q7F95NOgc"
72
+ genai.configure(api_key=apiKey)
73
+
74
+ model = genai.GenerativeModel("gemini-2.5-flash")
75
+
76
def retrieveDualSupportChunks(inputText, geminiQueryModel, maxChunksPerSource=10):
    """Retrieve support chunks twice — with the raw case facts and with a
    Gemini-generated search query — and merge the results per source.

    Args:
        inputText: raw case facts.
        geminiQueryModel: Gemini model used to rewrite the facts into a query.
        maxChunksPerSource: cap on merged unique chunks kept per index
            (default 10, preserving the original behaviour).

    Returns:
        (combinedSupport, geminiQuery): combinedSupport maps index name to a
        de-duplicated chunk list; geminiQuery is the generated query, or None
        if generation failed (retrieval then falls back to the raw facts).
    """
    try:
        geminiQuery = generateSearchQueryFromCase(inputText, geminiQueryModel)
    except Exception:
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit. Query generation stays best-effort: fall back to raw facts.
        geminiQuery = None

    supportFromCase, _ = retrieveSupportChunksParallel(inputText)
    supportFromQuery, _ = retrieveSupportChunksParallel(geminiQuery or inputText)

    combinedSupport = {}
    for key in supportFromCase:
        combined = supportFromCase[key] + supportFromQuery[key]
        seen = set()
        unique = []
        for chunk in combined:
            # Chunks may be plain strings or dicts; derive a stable text
            # representative for de-duplication.
            if isinstance(chunk, str):
                rep = chunk
            else:
                rep = chunk.get("text") or chunk.get("description") or chunk.get("section_desc") or str(chunk)
            if rep not in seen:
                seen.add(rep)
                unique.append(chunk)
            if len(unique) == maxChunksPerSource:
                break
        combinedSupport[key] = unique

    return combinedSupport, geminiQuery
103
+
104
import json

# Sanity-check both constitution chunk files by printing their first five entries.
# (The two original cells were identical apart from the path; a loop preserves
# the exact same prints in the same order.)
for path in (
    "/content/drive/MyDrive/faiss_indexes/constitution_bge_chunks.json",
    "/content/drive/MyDrive/faiss_indexes/constitution_chunks.json",
):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for i, item in enumerate(data[:5]):
        print(f"🔹 Chunk {i+1}:\n{item}\n")
124
+
125
+
126
import faiss
import numpy as np
import json
import pickle
from sentence_transformers import SentenceTransformer

# BGE-large embedder used for every retrieval query.
encoder = SentenceTransformer('BAAI/bge-large-en-v1.5')
# Root folder on Drive holding each FAISS index and its chunk metadata file.
basePath = "/content/drive/MyDrive/faiss_indexes"
134
+
135
def loadFaissIndexAndChunks(indexPath, chunkPath):
    """Load a FAISS index plus its parallel chunk store.

    Args:
        indexPath: path to the serialized FAISS index.
        chunkPath: path to the chunk metadata — a pickle (".pkl") or JSON file.

    Returns:
        (index, chunks): the FAISS index and the list of chunk records whose
        positions correspond to the index's vector ids.
    """
    index = faiss.read_index(indexPath)
    # FIX: branch once on the extension instead of evaluating endswith() twice,
    # and open JSON with an explicit UTF-8 encoding (platform-independent).
    if chunkPath.endswith('.pkl'):
        with open(chunkPath, 'rb') as f:
            chunks = pickle.load(f)
    else:
        with open(chunkPath, 'r', encoding='utf-8') as f:
            chunks = json.load(f)
    return index, chunks
140
+
141
def search(index, chunks, queryEmbedding, topK):
    """Run a top-K FAISS search and pair each hit with its chunk.

    Args:
        index: FAISS index (or any object with a compatible .search method).
        chunks: list whose positions correspond to the index's vector ids.
        queryEmbedding: 2-D query embedding batch (single query in row 0).
        topK: number of neighbours to request.

    Returns:
        List of (score, chunk) tuples for valid hits only.
    """
    D, I = index.search(queryEmbedding, topK)
    results = []
    for score, idx in zip(D[0], I[0]):
        # FIX: FAISS pads missing results with id -1; the original `idx < len(chunks)`
        # guard let -1 through, silently returning the *last* chunk via negative
        # indexing. Require a non-negative id as well.
        if 0 <= idx < len(chunks):
            results.append((score, chunks[idx]))
    return results
148
+
149
from concurrent.futures import ThreadPoolExecutor

def retrieveSupportChunksParallel(inputText):
    """Query every preloaded FAISS index concurrently for its top-5 chunks.

    Returns:
        (support, logs): support maps index name -> list of top chunks; logs
        records the query text and the chunks used.
    """
    queryEmbedding = encoder.encode([inputText], normalize_embeddings=True).astype('float32')
    # NOTE(review): the embedding is already unit-normalised by encode() above,
    # so this second normalisation is a harmless no-op kept for parity.
    faiss.normalize_L2(queryEmbedding)

    logs = {"query": inputText}

    def fetchTop(indexName):
        # One worker per index: look up the preloaded (index, chunks) pair and
        # keep only the chunk payloads from the scored hits.
        faissIndex, chunkList = preloadedIndexes[indexName]
        hits = search(faissIndex, chunkList, queryEmbedding, 5)
        return indexName, [chunk for _, chunk in hits]

    support = {}
    with ThreadPoolExecutor(max_workers=6) as pool:
        pending = [pool.submit(fetchTop, name) for name in preloadedIndexes]
        for future in pending:
            indexName, topChunks = future.result()
            support[indexName] = topChunks

    logs["supportChunksUsed"] = support
    return support, logs
170
+
171
# Pre-load every FAISS index and its chunk store once at import time so each
# retrieval call only pays the search cost. Keys are the retrieval-source names
# used by retrieveSupportChunksParallel and buildGeminiPrompt.
preloadedIndexes = {
    "constitution": loadFaissIndexAndChunks(f"{basePath}/constitution_bgeLarge.index", f"{basePath}/constitution_chunks.json"),
    "ipcSections": loadFaissIndexAndChunks(f"{basePath}/ipc_bgeLarge.index", f"{basePath}/ipc_chunks.json"),
    "ipcCase": loadFaissIndexAndChunks(f"{basePath}/ipc_case_flat.index", f"{basePath}/ipc_case_chunks.json"),
    "statutes": loadFaissIndexAndChunks(f"{basePath}/statute_index.faiss", f"{basePath}/statute_chunks.pkl"),
    "qaTexts": loadFaissIndexAndChunks(f"{basePath}/qa_faiss_index.idx", f"{basePath}/qa_text_chunks.json"),
    "caseLaw": loadFaissIndexAndChunks(f"{basePath}/case_faiss.index", f"{basePath}/case_chunks.pkl")
}
179
+
180
def generateSearchQueryFromCase(caseFacts, geminiModel, verbose=False):
    """Ask Gemini to compress case facts into a short legal search query.

    Args:
        caseFacts: raw case-fact text.
        geminiModel: model exposing generate_content(prompt) -> response.text.
        verbose: when True, print the cleaned query.

    Returns:
        The cleaned query string (label prefix, surrounding quotes, and
        newlines stripped from the model response).
    """
    prompt = f"""
You are a legal assistant for a retrieval system based on Indian criminal law.

Given the case facts below, generate a **concise and focused search query** with **only the most relevant legal keywords**. These should include:

- Specific **IPC sections**
- Core **legal concepts** (e.g., "right of private defence", "criminal breach of trust")
- **Crime type** (e.g., "assault", "corruption")
- Any relevant **procedural issue** (e.g., "absence of intent", "lack of evidence")

Do **not** include:
- Full sentences
- Personal names
- Generic or vague words (e.g., "man", "incident", "case", "situation")

Keep the query under **20 words**. Separate terms by commas if needed. Optimize for legal document search.

Case Facts:
\"\"\"{caseFacts}\"\"\"

Return only the search query, no explanation or prefix:
"""
    response = geminiModel.generate_content(prompt)
    rawText = response.text
    # Strip the label Gemini sometimes echoes, then surrounding quotes/newlines.
    withoutLabel = rawText.replace("Search Query:", "")
    query = withoutLabel.strip().strip('"').replace("\n", "")

    if verbose:
        print("RAG Query:", query)

    return query
210
+
211
def buildGeminiPrompt(inputText, modelVerdict, confidence, support, query=None):
    """Assemble the judge-persona prompt for Gemini.

    Args:
        inputText: case facts.
        modelVerdict: LegalBERT verdict string ("guilty" / "not guilty").
        confidence: LegalBERT confidence in [0, 1] (rendered as a percentage).
        support: mapping of retrieval-source name -> list of chunks.
        query: optional RAG search query to surface in the prompt.

    Returns:
        The full prompt string, ending with the fixed "Final Verdict" /
        "Verdict Changed" answer-format instructions parsed downstream by
        extractFinalVerdict.
    """
    verdictOutcome = "a loss for the person" if modelVerdict.lower() == "guilty" else "in favor of the person"

    prompt = f"""You are a judge evaluating a legal dispute under Indian law.

### Case Facts:
{inputText}

### Initial Model Verdict:
{modelVerdict.upper()} (Confidence: {confidence * 100:.2f}%)
This verdict is interpreted as {verdictOutcome}.
"""

    if query:
        prompt += f"\n### Legal Query Used:\n{query}\n"

    prompt += "\n---\n\n### Legal References Retrieved:\n"

    # One (support key, section heading) pair per retrieval source; rendering
    # them in a loop replaces the six copy-pasted blocks of the original while
    # producing byte-identical output.
    referenceSections = [
        ("constitution", "Constitution Articles"),
        ("ipcSections", "IPC Sections"),
        ("ipcCase", "IPC Case Law"),
        ("statutes", "Statutes"),
        ("qaTexts", "QA Texts"),
        ("caseLaw", "General Case Law"),
    ]
    for key, heading in referenceSections:
        prompt += f"\n#### {heading} (Top 5):\n"
        for i, item in enumerate(support.get(key, [])):
            prompt += f"- {i+1}. {str(item)}\n"

    prompt += """

---

### Instructions to the Judge (You):

1. Review the legal materials provided:
- Identify which Constitution articles, IPC sections, statutes, and case laws are relevant to the facts.
- Also note and explain which retrieved references are **not applicable** or irrelevant.

2. If relevant past cases appear in the retrieved materials, summarize them and analyze whether they support or contradict the model’s verdict.

3. Using the above, assess the model's prediction:
- If confidence is below 60%, you may revise or retain it.
- If confidence is 60% or higher, retain unless clear legal grounds exist to challenge it.

4. Provide a thorough and formal legal explanation that:
- Justifies the final decision using legal logic
- Cites relevant IPCs, constitutional provisions, statutes, and precedents
- Explains any reasoning for overriding the model's prediction, if applicable

5. Conclude with the following lines, formatted as shown:

Final Verdict: Guilty or Not Guilty
Verdict Changed: Yes or No

Respond in the tone of a formal Indian judge. Your explanation should reflect reasoning, neutrality, and respect for legal procedure.
"""
    return prompt
280
+
281
import re

def extractFinalVerdict(geminiOutput):
    """Parse Gemini's free-text ruling for the two trailing summary lines.

    Args:
        geminiOutput: full text returned by the judge model.

    Returns:
        (finalVerdict, verdictChanged):
          finalVerdict   - "guilty" / "not guilty", or None when no verdict
                           line is found.
          verdictChanged - "changed" when the model's verdict was overturned,
                           else "not changed".

    FIX: the separator now also accepts '*' so the markdown-bold form Gemini
    frequently emits ("**Final Verdict:** Guilty") is recognised; the original
    pattern only matched the plain "Final Verdict: Guilty" form.
    """
    verdictMatch = re.search(r"final verdict[\s*:\-]+(guilty|not guilty)", geminiOutput, re.IGNORECASE)
    changedMatch = re.search(r"verdict changed[\s*:\-]+(yes|no)", geminiOutput, re.IGNORECASE)

    finalVerdict = verdictMatch.group(1).lower() if verdictMatch else None
    verdictChanged = "changed" if changedMatch and changedMatch.group(1).lower() == "yes" else "not changed"

    return finalVerdict, verdictChanged
291
+
292
def evaluateCaseWithGemini(inputText, modelVerdict, confidence, retrieveFn, geminiQueryModel=None):
    """Full RAG pipeline step: retrieve legal support, prompt Gemini as a judge,
    and parse its final ruling.

    Args:
        inputText: case facts.
        modelVerdict: LegalBERT verdict string ("guilty" / "not guilty").
        confidence: LegalBERT confidence in [0, 1].
        retrieveFn: retrieval function used when no query model is given.
        geminiQueryModel: optional Gemini model for query rewriting; when set,
            dual retrieval (raw facts + generated query) is used instead.

    Returns:
        A log dict with the prompt, Gemini output, parsed verdict, and all
        inputs. On any failure, a dict carrying an "error" message with every
        result field set to None (best-effort: callers never see an exception).
    """
    try:
        if geminiQueryModel:
            support, searchQuery = retrieveDualSupportChunks(inputText, geminiQueryModel)
        else:
            support, _ = retrieveFn(inputText)
            searchQuery = inputText

        prompt = buildGeminiPrompt(inputText, modelVerdict, confidence, support, searchQuery)
        # NOTE(review): the judgment is generated with the module-level `model`,
        # not geminiQueryModel — confirm that is intentional.
        response = model.generate_content(prompt)
        geminiOutput = response.text

        finalVerdict, verdictChanged = extractFinalVerdict(geminiOutput)

        logs = {
            "inputText": inputText,
            "modelVerdict": modelVerdict,
            "confidence": confidence,
            "support": support,
            "promptToGemini": prompt,
            "geminiOutput": geminiOutput,
            "finalVerdictByGemini": finalVerdict,
            "verdictChanged": verdictChanged,
            "ragSearchQuery": searchQuery
        }

        return logs

    except Exception as e:
        # Deliberate broad catch: surface the error in the log dict so batch
        # evaluation loops keep running.
        return dict(
            error=str(e),
            inputText=inputText,
            modelVerdict=modelVerdict,
            confidence=confidence,
            ragSearchQuery=None,
            support=None,
            promptToGemini=None,
            geminiOutput=None,
            finalVerdictByGemini=None,
            verdictChanged=None
        )
333
+
334
import pandas as pd

# Load the ILDC test split and run the full pipeline on a single sample case.
df=pd.read_csv('/content/drive/MyDrive/Extracted/LegalRAGSystem/ILDC/test.csv')

# Notebook display of the ground-truth label for row 1971 (for manual comparison).
df['Label'][1971]

inputText = df['Input'][1971]

verdict = predictVerdict(inputText)
confidence = getConfidence(inputText)

logs = evaluateCaseWithGemini(
    inputText=inputText,
    modelVerdict=verdict,
    confidence=confidence,
    retrieveFn=retrieveSupportChunksParallel,
    geminiQueryModel=model
)

print("🔍 Query sent to RAG:", logs["ragSearchQuery"])
print(logs['modelVerdict'])
print(logs['confidence'])
# print("\n📜 Prompt to Gemini:\n", logs["promptToGemini"])
print("\n🧑‍⚖️ Gemini Verdict Output:\n", logs["geminiOutput"])
print("\n✅ Final Verdict:", logs["finalVerdictByGemini"])
print("🔁 Verdict Changed:", logs["verdictChanged"])
360
+
361
# NOTE: disabled batch-evaluation loop — samples 5 random ILDC rows and measures
# Gemini's final-verdict accuracy against the ground-truth labels. Kept for
# reference; re-enable by uncommenting.
# import random

# sampleIndices = random.sample(range(len(df)), 5)
# correctCount = 0
# total = 0

# for idx in sampleIndices:
# inputText = df['Input'][idx]
# trueLabel = int(df['Label'][idx])

# verdict = predictVerdict(inputText)
# confidence = getConfidence(inputText)

# result = evaluateCaseWithGemini(
# inputText=inputText,
# modelVerdict=verdict,
# confidence=confidence,
# retrieveFn=retrieveSupportChunksParallel,
# geminiQueryModel=model
# )

# predicted = result.get("finalVerdictByGemini")
# if predicted is None:
# continue

# predictedLabel = 1 if predicted.lower() == "guilty" else 0

# print("Index:", idx)
# print("True Label:", trueLabel)
# print("Predicted Verdict:", predicted)
# print("Verdict Changed:", result.get("verdictChanged"))
# print("Match:", predictedLabel == trueLabel)
# print("----")

# correctCount += int(predictedLabel == trueLabel)
# total += 1

# print("Samples Evaluated:", total)
# print("Gemini Final Verdict Accuracy:", correctCount / total if total else 0)
400
+