import hashlib
import re
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import faiss
import networkx as nx
import nltk
import numpy as np
import spacy
import torch
from groq import Groq
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity  # noqa: F401 - currently unused here
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# Global variables for models (populated by initialize_models)
MODEL = None
TOKENIZER = None
GROQ_CLIENT = None
NLP_MODEL = None
DEVICE = None

# Global indices (populated by build_all_indices)
DENSE_INDEX = None
BM25_INDEX = None
CONCEPT_GRAPH = None
TOKEN_TO_CHUNKS = None
CHUNKS_DATA = []

# Legal knowledge base
LEGAL_CONCEPTS = {
    'liability': ['negligence', 'strict liability', 'vicarious liability', 'product liability'],
    'contract': ['breach', 'consideration', 'offer', 'acceptance', 'damages', 'specific performance'],
    'criminal': ['mens rea', 'actus reus', 'intent', 'malice', 'premeditation'],
    'procedure': ['jurisdiction', 'standing', 'statute of limitations', 'res judicata'],
    'evidence': ['hearsay', 'relevance', 'privilege', 'burden of proof', 'admissibility'],
    'constitutional': ['due process', 'equal protection', 'free speech', 'search and seizure'],
}

QUERY_PATTERNS = {
    'precedent': ['case', 'precedent', 'ruling', 'held', 'decision'],
    'statute_interpretation': ['statute', 'section', 'interpretation', 'meaning', 'definition'],
    'factual': ['what happened', 'facts', 'circumstances', 'events'],
    'procedure': ['how to', 'procedure', 'process', 'filing', 'requirements'],
}

# Chat model used for HyDE expansion and answer generation.
_GROQ_MODEL = "llama-3.1-8b-instant"

# Reporter-style case citations, e.g. "347 Cal. 483".
_CITATION_RE = re.compile(r'\b\d+\s+[A-Z][a-z]+\.?\s+\d+\b')
# Statute references, e.g. "§ 1983", "§12.5" or "Section 4". The section sign is
# written as the escape \u00a7 so the pattern survives file re-encoding (a
# literal "§" had previously been mangled into "ยง", which never matched).
_STATUTE_RE = re.compile(r'\u00a7\s*\d+[\.\d]*|\bSection\s+\d+')
# Word tokens for the BM25 / token indices; punctuation is discarded so that
# "negligence?" in a query matches "negligence." in a document.
_WORD_RE = re.compile(r'\w+')

# Headings that mark the start of a legal section, in priority order.
_SECTION_PATTERNS = [
    (re.compile(r'(?i)\bFACTS?\b[:\s]'), 'facts'),
    (re.compile(r'(?i)\bHOLDING\b[:\s]'), 'holding'),
    (re.compile(r'(?i)\bREASONING\b[:\s]'), 'reasoning'),
    (re.compile(r'(?i)\bDISSENT\b[:\s]'), 'dissent'),
    (re.compile(r'(?i)\bCONCLUSION\b[:\s]'), 'conclusion'),
]

# Importance multiplier per section type; unlisted sections use 1.0.
_SECTION_WEIGHTS = {
    'holding': 2.0,
    'conclusion': 1.8,
    'reasoning': 1.5,
    'facts': 1.2,
    'dissent': 0.8,
}

# Weight of each retrieval signal in the fused score.
_FUSION_WEIGHTS = {'dense': 0.35, 'bm25': 0.25, 'entity': 0.25, 'graph': 0.15}


def _tokenize(text: str) -> List[str]:
    """Lower-case *text* and split it into word tokens, dropping punctuation."""
    return _WORD_RE.findall(text.lower())


def initialize_models(model_id: str, groq_api_key: Optional[str] = None):
    """Load the embedding model and the optional Groq and spaCy components.

    Sets the module globals MODEL, TOKENIZER, DEVICE, GROQ_CLIENT and NLP_MODEL.

    Args:
        model_id: Model name/path passed to ``AutoTokenizer``/``AutoModel``.
        groq_api_key: When given, a Groq client is created; it is needed for
            HyDE query expansion and answer generation.
    """
    global MODEL, TOKENIZER, GROQ_CLIENT, NLP_MODEL, DEVICE

    # Best-effort NLTK data download. Recent NLTK releases load 'punkt_tab'
    # for sent_tokenize, older ones 'punkt'; fetch both. If tokenization data
    # is still missing, chunking falls back to a naive sentence split.
    for resource in ('punkt', 'punkt_tab', 'stopwords'):
        try:
            nltk.download(resource, quiet=True)
        except Exception as exc:
            print(f"Could not download NLTK resource '{resource}': {exc}")

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    print(f"Loading model: {model_id}")
    TOKENIZER = AutoTokenizer.from_pretrained(model_id)
    MODEL = AutoModel.from_pretrained(model_id).to(DEVICE)
    MODEL.eval()

    if groq_api_key:
        GROQ_CLIENT = Groq(api_key=groq_api_key)

    try:
        NLP_MODEL = spacy.load("en_core_web_sm")
    except OSError:
        # spaCy raises OSError when the model package is not installed.
        print("SpaCy model not found, using basic NER")
        NLP_MODEL = None


def create_embedding(text: str) -> np.ndarray:
    """Return the L2-normalised, attention-masked mean-pooled embedding of *text*.

    Input is truncated to 512 tokens. The result is a 1-D array (the first and
    only row of the batch).
    """
    inputs = TOKENIZER(
        text, padding=True, truncation=True, max_length=512, return_tensors='pt'
    ).to(DEVICE)

    with torch.no_grad():
        outputs = MODEL(**inputs)

    # Mean-pool token vectors, ignoring padding positions.
    attention_mask = inputs['attention_mask']
    token_embeddings = outputs.last_hidden_state
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, 1)
    counts = torch.clamp(mask.sum(1), min=1e-9)
    embeddings = summed / counts

    # Unit length, so inner product == cosine similarity in the FAISS index.
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy()[0]


def extract_legal_entities(text: str) -> List[Dict[str, Any]]:
    """Extract named entities, case citations and statute references from *text*.

    Returns a list of ``{'text', 'type', 'importance'}`` dicts. spaCy entities
    (PERSON/ORG/LAW/GPE, only when NLP_MODEL is loaded) get importance 1.0,
    case citations 2.0 and statute references 1.5. Repeated mentions produce
    repeated entries.
    """
    entities = []

    if NLP_MODEL is not None:
        doc = NLP_MODEL(text[:5000])  # Limit for performance
        for ent in doc.ents:
            if ent.label_ in ('PERSON', 'ORG', 'LAW', 'GPE'):
                entities.append({'text': ent.text, 'type': ent.label_, 'importance': 1.0})

    for match in _CITATION_RE.finditer(text):
        entities.append({'text': match.group(), 'type': 'case_citation', 'importance': 2.0})

    for match in _STATUTE_RE.finditer(text):
        entities.append({'text': match.group(), 'type': 'statute', 'importance': 1.5})

    return entities


def analyze_query(query: str) -> Dict[str, Any]:
    """Classify *query*, extract its entities/concepts and build expanded queries.

    Returns a dict with keys ``original_query``, ``query_type`` (first matching
    QUERY_PATTERNS category, else ``'general'``), ``entities``, ``key_concepts``
    and ``expanded_queries`` (at most 4; the original query is always first).
    """
    query_lower = query.lower()

    query_type = next(
        (
            qtype
            for qtype, patterns in QUERY_PATTERNS.items()
            if any(pattern in query_lower for pattern in patterns)
        ),
        'general',
    )

    entities = extract_legal_entities(query)

    key_concepts = [
        concept
        for concepts in LEGAL_CONCEPTS.values()
        for concept in concepts
        if concept in query_lower
    ]

    expanded_queries = [query]

    # Concept expansion
    if key_concepts:
        expanded_queries.append(f"{query} {' '.join(key_concepts[:3])}")

    # Type-based expansion
    if query_type == 'precedent':
        expanded_queries.append(f"legal precedent case law {query}")
    elif query_type == 'statute_interpretation':
        expanded_queries.append(f"statutory interpretation meaning {query}")

    # HyDE: also search with a hypothetical answer document.
    if GROQ_CLIENT:
        hyde_doc = generate_hypothetical_document(query)
        if hyde_doc:
            expanded_queries.append(hyde_doc)

    return {
        'original_query': query,
        'query_type': query_type,
        'entities': entities,
        'key_concepts': key_concepts,
        'expanded_queries': expanded_queries[:4],  # Limit to 4 queries
    }


def generate_hypothetical_document(query: str) -> Optional[str]:
    """Ask the LLM for a short hypothetical passage answering *query* (HyDE).

    Returns None when no Groq client is configured or the request fails;
    HyDE is an optional expansion, so failures must not break retrieval.
    """
    if not GROQ_CLIENT:
        return None

    prompt = f"""Generate a brief hypothetical legal document excerpt that would answer this question: {query}
Write it as if it's from an actual legal case or statute. Be specific and use legal language.
Keep it under 100 words."""

    try:
        response = GROQ_CLIENT.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a legal expert generating hypothetical legal text."},
                {"role": "user", "content": prompt},
            ],
            model=_GROQ_MODEL,
            temperature=0.3,
            max_tokens=150,
        )
        return response.choices[0].message.content
    except Exception as exc:
        print(f"HyDE generation failed: {exc}")
        return None


def _make_chunk(chunk_text: str, title: str, position: int, section_type: str) -> Dict[str, Any]:
    """Build one chunk record, scoring importance by section type and entities."""
    entities = extract_legal_entities(chunk_text)

    importance = _SECTION_WEIGHTS.get(section_type, 1.0)
    if entities:
        entity_score = sum(e['importance'] for e in entities) / len(entities)
        importance *= 1 + entity_score * 0.5

    chunk_id = hashlib.md5(f"{title}_{position}_{chunk_text[:50]}".encode()).hexdigest()[:12]
    return {
        'id': chunk_id,
        'text': chunk_text,
        'title': title,
        'section_type': section_type,
        'importance_score': importance,
        'entities': entities,
        'embedding': None,  # Will be filled during indexing
    }


def chunk_text_hierarchical(text: str, title: str = "") -> List[Dict[str, Any]]:
    """Split *text* into sentence-aligned chunks labelled with their legal section.

    Whitespace is collapsed and sentences are grouped until a chunk reaches
    500 words or 10 sentences. Each chunk's ``section_type`` is the section
    (see _SECTION_PATTERNS) containing its last sentence; text before the first
    heading is ``'introduction'``. Every chunk, including the trailing
    remainder, gets the same importance scoring.

    Returns:
        A list of chunk dicts (empty for blank input).
    """
    chunks: List[Dict[str, Any]] = []

    text = re.sub(r'\s+', ' ', text).strip()
    if not text:
        return chunks

    # (start offset, section type) of every section heading, in text order.
    sections = []
    for pattern, section_type in _SECTION_PATTERNS:
        sections.extend((match.start(), section_type) for match in pattern.finditer(text))
    sections.sort(key=lambda x: x[0])

    try:
        sentences = nltk.sent_tokenize(text)
    except LookupError:
        # Tokenizer data unavailable; fall back to a naive split.
        sentences = text.split('. ')

    chunk_size = 500  # words
    current_section = 'introduction'
    section_sentences: List[str] = []
    word_count = 0
    cursor = 0  # search forward so repeated sentences map to the right offset

    for sent in sentences:
        if not sent.strip():
            continue

        sent_pos = text.find(sent, cursor)
        if sent_pos != -1:
            cursor = sent_pos + len(sent)
            # Latest heading at or before this sentence sets the section.
            for section_start, section_type in sections:
                if section_start > sent_pos:
                    break
                current_section = section_type

        section_sentences.append(sent)
        word_count += len(sent.split())

        if word_count >= chunk_size or len(section_sentences) >= 10:
            chunks.append(_make_chunk(' '.join(section_sentences), title, len(chunks), current_section))
            section_sentences = []
            word_count = 0

    # Add remaining sentences
    if section_sentences:
        chunks.append(_make_chunk(' '.join(section_sentences), title, len(chunks), current_section))

    return chunks


def build_all_indices(chunks: List[Dict[str, Any]]):
    """(Re)build every retrieval index over *chunks* and store them globally.

    Builds the FAISS inner-product index (filling each chunk's ``embedding``),
    the BM25 index, the token -> chunk-index map and the entity co-occurrence
    graph (nodes are chunk indices; edge weight = number of shared entity
    texts). An empty *chunks* list clears all indices.
    """
    global DENSE_INDEX, BM25_INDEX, CONCEPT_GRAPH, TOKEN_TO_CHUNKS, CHUNKS_DATA

    CHUNKS_DATA = chunks
    print(f"Building indices for {len(chunks)} chunks...")

    if not chunks:
        # np.vstack / BM25Okapi cannot handle an empty corpus.
        DENSE_INDEX = None
        BM25_INDEX = None
        CONCEPT_GRAPH = None
        TOKEN_TO_CHUNKS = None
        print("No chunks to index; indices cleared.")
        return

    # 1. Dense embeddings + FAISS index
    print("Building FAISS index...")
    embeddings = []
    for chunk in tqdm(chunks, desc="Creating embeddings"):
        embedding = create_embedding(chunk['text'])
        chunk['embedding'] = embedding
        embeddings.append(embedding)

    embeddings_matrix = np.vstack(embeddings).astype('float32')
    DENSE_INDEX = faiss.IndexFlatIP(embeddings_matrix.shape[1])  # Inner product for normalized vectors
    DENSE_INDEX.add(embeddings_matrix)

    # 2. BM25 index for sparse retrieval
    print("Building BM25 index...")
    tokenized_corpus = [_tokenize(chunk['text']) for chunk in chunks]
    BM25_INDEX = BM25Okapi(tokenized_corpus)

    # 3. ColBERT-style token index
    print("Building ColBERT token index...")
    TOKEN_TO_CHUNKS = defaultdict(set)
    for i, tokens in enumerate(tokenized_corpus):
        for token in tokens:
            TOKEN_TO_CHUNKS[token].add(i)

    # 4. Legal concept graph
    print("Building legal concept graph...")
    CONCEPT_GRAPH = nx.Graph()
    for i, chunk in enumerate(chunks):
        CONCEPT_GRAPH.add_node(i, text=chunk['text'][:200], importance=chunk['importance_score'])

    # Entity sets are computed once rather than per pair.
    entity_sets = [{e['text'] for e in chunk['entities']} for chunk in chunks]
    for i, entities_i in enumerate(entity_sets):
        if not entities_i:
            continue
        for j, entities_j in enumerate(entity_sets[i + 1:], start=i + 1):
            shared_entities = entities_i & entities_j
            if shared_entities:
                CONCEPT_GRAPH.add_edge(i, j, weight=len(shared_entities))

    print("All indices built successfully!")


def _normalize_score(method: str, score: float) -> float:
    """Map a raw per-method retrieval score roughly onto [0, 1]."""
    if method == 'dense':
        return (score + 1) / 2  # cosine in [-1, 1] -> [0, 1]
    if method == 'bm25':
        return min(score / 10, 1)
    if method == 'entity':
        return min(score / 3, 1)
    return score


def multi_stage_retrieval(query_analysis: Dict[str, Any], top_k: int = 10) -> List[Tuple[Dict[str, Any], float]]:
    """Retrieve chunks by fusing dense, BM25, entity and graph signals.

    Args:
        query_analysis: Output of ``analyze_query``.
        top_k: Number of results to return.

    Returns:
        Up to *top_k* ``(chunk, score)`` pairs, best first. The score is the
        weighted sum of normalised signals, multiplied by the chunk's
        importance and by 1.5 when the section type suits the query type.
    """
    if not CHUNKS_DATA or top_k <= 0:
        return []

    n_chunks = len(CHUNKS_DATA)
    # chunk index -> {method: raw score}. Keyed by index (not chunk id) so
    # hash-id collisions cannot merge distinct chunks.
    candidates: Dict[int, Dict[str, float]] = {}

    print("Performing multi-stage retrieval...")

    # Stage 1: Dense retrieval over every expanded query (incl. HyDE).
    print("Stage 1: Dense retrieval...")
    if DENSE_INDEX is not None and DENSE_INDEX.ntotal > 0:
        # FAISS pads with label -1 when k > ntotal, so clamp k and skip -1.
        k = min(top_k * 2, DENSE_INDEX.ntotal)
        for query in query_analysis['expanded_queries']:
            query_emb = create_embedding(query).reshape(1, -1).astype('float32')
            scores, indices = DENSE_INDEX.search(query_emb, k)
            for idx, score in zip(indices[0], scores[0]):
                if 0 <= idx < n_chunks:
                    entry = candidates.setdefault(int(idx), {})
                    # Keep the best similarity across the expanded queries.
                    entry['dense'] = max(entry.get('dense', float('-inf')), float(score))

    # Stage 2: Sparse retrieval (BM25)
    print("Stage 2: Sparse retrieval...")
    query_tokens = _tokenize(query_analysis['original_query'])
    if BM25_INDEX is not None and query_tokens:
        bm25_scores = BM25_INDEX.get_scores(query_tokens)
        for idx in np.argsort(bm25_scores)[::-1][:top_k * 2]:
            if bm25_scores[idx] > 0:  # no term overlap -> not a candidate
                candidates.setdefault(int(idx), {})['bm25'] = float(bm25_scores[idx])

    # Stage 3: Entity-based retrieval
    print("Stage 3: Entity-based retrieval...")
    query_entities = query_analysis['entities']
    if query_entities:
        for idx, chunk in enumerate(CHUNKS_DATA):
            chunk_entity_texts = {e['text'].lower() for e in chunk['entities']}
            for entity in query_entities:
                if entity['text'].lower() in chunk_entity_texts:
                    entry = candidates.setdefault(idx, {})
                    entry['entity'] = entry.get('entity', 0) + entity['importance']

    # Stage 4: Graph-based retrieval from the first five candidates.
    print("Stage 4: Graph-based retrieval...")
    if candidates and CONCEPT_GRAPH is not None:
        for seed_idx in list(candidates)[:5]:
            if seed_idx not in CONCEPT_GRAPH:
                continue
            # Strongest co-occurrence links first (ties keep index order).
            neighbors = sorted(
                CONCEPT_GRAPH[seed_idx].items(),
                key=lambda item: item[1].get('weight', 1),
                reverse=True,
            )[:3]
            for neighbor_idx, _ in neighbors:
                if 0 <= neighbor_idx < n_chunks:
                    candidates.setdefault(neighbor_idx, {})['graph'] = 0.5

    # Combine scores
    print("Combining scores...")
    query_type = query_analysis['query_type']
    final_scores = []
    for idx, scores in candidates.items():
        chunk = CHUNKS_DATA[idx]
        final_score = sum(
            weight * _normalize_score(method, scores[method])
            for method, weight in _FUSION_WEIGHTS.items()
            if method in scores
        )

        # Boost by importance and section relevance
        final_score *= chunk['importance_score']
        if query_type == 'precedent' and chunk['section_type'] == 'holding':
            final_score *= 1.5
        elif query_type == 'factual' and chunk['section_type'] == 'facts':
            final_score *= 1.5

        final_scores.append((chunk, final_score))

    final_scores.sort(key=lambda x: x[1], reverse=True)
    return final_scores[:top_k]


def generate_answer_with_reasoning(query: str, retrieved_chunks: List[Tuple[Dict[str, Any], float]]) -> Dict[str, Any]:
    """Ask the LLM for an IRAC-structured answer grounded in *retrieved_chunks*.

    Returns a dict with ``answer``, ``confidence`` (mean score of the top 3
    chunks x 100, capped at 100) and ``sources``; or a dict with ``error``
    (plus up to 3 raw sources when the LLM call fails).
    """
    if not GROQ_CLIENT:
        return {'error': 'Groq client not initialized'}

    context_parts = []
    for i, (chunk, score) in enumerate(retrieved_chunks, 1):
        entities = ', '.join(e['text'] for e in chunk['entities'][:3])
        context_parts.append(
            f"""
Document {i} [{chunk['title']}] - Relevance: {score:.2f}
Section: {chunk['section_type']}
Key Entities: {entities}
Content: {chunk['text'][:800]}
"""
        )
    context = "\n---\n".join(context_parts)

    system_prompt = """You are an expert legal analyst. Provide thorough legal analysis using the IRAC method:
1. ISSUE: Identify the legal issue(s)
2. RULE: State the applicable legal rules/precedents
3. APPLICATION: Apply the rules to the facts
4. CONCLUSION: Provide a clear conclusion

CRITICAL: Base ALL responses on the provided document excerpts only. Quote directly when making claims.
If information is not in the excerpts, state "This information is not provided in the available documents."
"""

    user_prompt = f"""Query: {query}

Retrieved Legal Documents:
{context}

Please provide a comprehensive legal analysis using IRAC method. Cite the documents when making claims."""

    try:
        response = GROQ_CLIENT.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            model=_GROQ_MODEL,
            temperature=0.1,
            max_tokens=1000,
        )
        answer = response.choices[0].message.content

        top_chunks = retrieved_chunks[:3]
        avg_score = sum(score for _, score in top_chunks) / len(top_chunks) if top_chunks else 0.0
        confidence = min(avg_score * 100, 100)

        return {
            'answer': answer,
            'confidence': confidence,
            'sources': [
                {
                    'chunk_id': chunk['id'],
                    'title': chunk['title'],
                    'section': chunk['section_type'],
                    'relevance_score': float(score),
                    'excerpt': chunk['text'][:200] + ('...' if len(chunk['text']) > 200 else ''),
                    'entities': [e['text'] for e in chunk['entities'][:5]],
                }
                for chunk, score in retrieved_chunks
            ],
        }
    except Exception as e:
        return {
            'error': f'Error generating answer: {str(e)}',
            'sources': [{'chunk': c['text'][:200], 'score': s} for c, s in retrieved_chunks[:3]],
        }


# Main functions for external use

def process_documents(documents: List[Dict[str, str]]) -> Dict[str, Any]:
    """Chunk every ``{'text', 'title'?}`` document and rebuild all indices.

    Replaces any previously indexed corpus. Titles default to ``'Document'``.
    """
    all_chunks = []
    for doc in documents:
        all_chunks.extend(chunk_text_hierarchical(doc['text'], doc.get('title', 'Document')))

    build_all_indices(all_chunks)

    return {
        'success': True,
        'chunk_count': len(all_chunks),
        'message': f'Processed {len(documents)} documents into {len(all_chunks)} chunks',
    }


def query_documents(query: str, top_k: int = 5) -> Dict[str, Any]:
    """Answer *query* from the indexed corpus; the result includes ``query_analysis``."""
    if not CHUNKS_DATA:
        return {'error': 'No documents indexed. Call process_documents first.'}

    query_analysis = analyze_query(query)
    retrieved_chunks = multi_stage_retrieval(query_analysis, top_k)

    if not retrieved_chunks:
        return {
            'error': 'No relevant documents found',
            'query_analysis': query_analysis,
        }

    result = generate_answer_with_reasoning(query, retrieved_chunks)
    result['query_analysis'] = query_analysis
    return result


def search_chunks_simple(query: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """Return ``[{'chunk': {'id', 'text', 'title'}, 'score'}]`` for *query* (compatibility API)."""
    if not CHUNKS_DATA:
        return []

    query_analysis = analyze_query(query)
    retrieved_chunks = multi_stage_retrieval(query_analysis, top_k)

    return [
        {
            'chunk': {
                'id': chunk['id'],
                'text': chunk['text'],
                'title': chunk['title'],
            },
            'score': score,
        }
        for chunk, score in retrieved_chunks
    ]


def generate_conservative_answer(query: str, context_chunks: List[Dict[str, Any]]) -> str:
    """Answer *query* from ``search_chunks_simple``-style results (compatibility API).

    Returns the answer text, or the error message when generation fails.
    """
    if not context_chunks:
        return "No relevant information found."

    retrieved_chunks = [(item['chunk'], item['score']) for item in context_chunks]
    result = generate_answer_with_reasoning(query, retrieved_chunks)

    if 'error' in result:
        return result['error']
    return result.get('answer', 'Unable to generate answer.')