Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModel | |
| from typing import List, Dict, Any, Tuple, Optional | |
| import faiss | |
| import hashlib | |
| from tqdm import tqdm | |
| from groq import Groq | |
| import re | |
| import nltk | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import networkx as nx | |
| from collections import defaultdict | |
| import spacy | |
| from rank_bm25 import BM25Okapi | |
# Model handles; initialize_models() assigns these.
MODEL = TOKENIZER = GROQ_CLIENT = NLP_MODEL = DEVICE = None

# Retrieval indices; build_all_indices() assigns these.
DENSE_INDEX = BM25_INDEX = CONCEPT_GRAPH = TOKEN_TO_CHUNKS = None
CHUNKS_DATA = []

# Legal knowledge base: category -> concept phrases looked up in queries.
LEGAL_CONCEPTS = {
    'liability': [
        'negligence',
        'strict liability',
        'vicarious liability',
        'product liability',
    ],
    'contract': [
        'breach',
        'consideration',
        'offer',
        'acceptance',
        'damages',
        'specific performance',
    ],
    'criminal': [
        'mens rea',
        'actus reus',
        'intent',
        'malice',
        'premeditation',
    ],
    'procedure': [
        'jurisdiction',
        'standing',
        'statute of limitations',
        'res judicata',
    ],
    'evidence': [
        'hearsay',
        'relevance',
        'privilege',
        'burden of proof',
        'admissibility',
    ],
    'constitutional': [
        'due process',
        'equal protection',
        'free speech',
        'search and seizure',
    ],
}

# Query type -> cue phrases; analyze_query() picks the first type with a hit.
QUERY_PATTERNS = {
    'precedent': ['case', 'precedent', 'ruling', 'held', 'decision'],
    'statute_interpretation': [
        'statute',
        'section',
        'interpretation',
        'meaning',
        'definition',
    ],
    'factual': ['what happened', 'facts', 'circumstances', 'events'],
    'procedure': ['how to', 'procedure', 'process', 'filing', 'requirements'],
}
def initialize_models(model_id: str, groq_api_key: Optional[str] = None):
    """Load the embedding model and optional helpers into module globals.

    Args:
        model_id: Identifier passed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``.
        groq_api_key: When truthy, a ``Groq`` client is created with it;
            otherwise ``GROQ_CLIENT`` is left as it was.

    Side effects:
        Sets ``MODEL``, ``TOKENIZER``, ``DEVICE``, ``NLP_MODEL`` and possibly
        ``GROQ_CLIENT``. Tokenizer/model loading errors propagate; NLTK and
        spaCy setup are best-effort.
    """
    global MODEL, TOKENIZER, GROQ_CLIENT, NLP_MODEL, DEVICE

    # Best-effort: each resource is attempted separately so one failure does
    # not skip the other. `Exception` (not a bare except) keeps Ctrl-C working.
    for resource in ('punkt', 'stopwords'):
        try:
            nltk.download(resource, quiet=True)
        except Exception as exc:
            print(f"Could not download NLTK resource '{resource}': {exc}")

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    print(f"Loading model: {model_id}")
    TOKENIZER = AutoTokenizer.from_pretrained(model_id)
    MODEL = AutoModel.from_pretrained(model_id).to(DEVICE)
    MODEL.eval()

    if groq_api_key:
        GROQ_CLIENT = Groq(api_key=groq_api_key)

    # Without the spaCy model, extract_legal_entities() falls back to its
    # regex-only rules.
    try:
        NLP_MODEL = spacy.load("en_core_web_sm")
    except Exception:
        print("SpaCy model not found, using basic NER")
        NLP_MODEL = None
def create_embedding(text: str) -> np.ndarray:
    """Embed ``text`` as one L2-normalised, mean-pooled vector.

    The input is tokenized with truncation at 512 tokens, run through
    ``MODEL``, and the last hidden states are averaged over the positions
    selected by the attention mask.

    Returns:
        The first (only) row of the pooled batch as a NumPy array.
    """
    encoded = TOKENIZER(
        text,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    ).to(DEVICE)

    with torch.no_grad():
        hidden_states = MODEL(**encoded).last_hidden_state

        # Broadcast the mask over the hidden dimension so padded positions
        # contribute nothing to the sum.
        mask = encoded['attention_mask'].unsqueeze(-1)
        mask = mask.expand(hidden_states.size()).float()

        summed = torch.sum(hidden_states * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)  # avoid division by zero
        pooled = torch.nn.functional.normalize(summed / counts, p=2, dim=1)

    return pooled.cpu().numpy()[0]
def extract_legal_entities(text: str) -> List[Dict[str, Any]]:
    """Collect named entities, case citations and statute references.

    Results are ordered: spaCy entities (only when ``NLP_MODEL`` is loaded,
    and only from the first 5000 characters), then citation matches, then
    statute matches. Each item is a dict with ``text``, ``type`` and
    ``importance`` keys.
    """
    found: List[Dict[str, Any]] = []

    if NLP_MODEL:
        wanted_labels = ['PERSON', 'ORG', 'LAW', 'GPE']
        parsed = NLP_MODEL(text[:5000])  # cap input length for performance
        found.extend(
            {'text': span.text, 'type': span.label_, 'importance': 1.0}
            for span in parsed.ents
            if span.label_ in wanted_labels
        )

    # (regex, entity type, importance) applied to the full text, in order.
    regex_rules = (
        (r'\b\d+\s+[A-Z][a-z]+\.?\s+\d+\b', 'case_citation', 2.0),
        (r'§\s*\d+[\.\d]*|\bSection\s+\d+', 'statute', 1.5),
    )
    for pattern, entity_type, weight in regex_rules:
        for hit in re.finditer(pattern, text):
            found.append({
                'text': hit.group(),
                'type': entity_type,
                'importance': weight,
            })

    return found
def analyze_query(query: str) -> Dict[str, Any]:
    """Classify a query and derive expanded variants for retrieval.

    Returns a dict with the original query, the first matching query type
    from ``QUERY_PATTERNS`` (``'general'`` when none match), extracted
    entities, matched ``LEGAL_CONCEPTS`` phrases, and up to four expanded
    queries. All cue/concept matching is plain substring search on the
    lower-cased query.
    """
    lowered = query.lower()

    # First query type (in dict order) with any cue phrase present wins.
    query_type = next(
        (
            name
            for name, cues in QUERY_PATTERNS.items()
            if any(cue in lowered for cue in cues)
        ),
        'general',
    )

    entities = extract_legal_entities(query)

    key_concepts = [
        term
        for terms in LEGAL_CONCEPTS.values()
        for term in terms
        if term in lowered
    ]

    expansions = [query]

    # Concept expansion: append at most three matched concept phrases.
    if key_concepts:
        top_concepts = ' '.join(key_concepts[:3])
        expansions.append(f"{query} {top_concepts}")

    # Type-based expansion: prefix domain vocabulary for selected types.
    type_prefixes = {
        'precedent': "legal precedent case law",
        'statute_interpretation': "statutory interpretation meaning",
    }
    if query_type in type_prefixes:
        expansions.append(f"{type_prefixes[query_type]} {query}")

    # HyDE: add an LLM-written hypothetical answer when a client exists.
    if GROQ_CLIENT:
        hypothetical = generate_hypothetical_document(query)
        if hypothetical:
            expansions.append(hypothetical)

    return {
        'original_query': query,
        'query_type': query_type,
        'entities': entities,
        'key_concepts': key_concepts,
        'expanded_queries': expansions[:4],  # cap at four queries
    }
def generate_hypothetical_document(query: str) -> Optional[str]:
    """Ask the Groq model for a short hypothetical answer passage (HyDE).

    Args:
        query: The user question the hypothetical excerpt should answer.

    Returns:
        The generated text, or ``None`` when no client is configured or the
        request fails. Failures are reported with ``print`` and never raised.
    """
    if not GROQ_CLIENT:
        return None

    prompt = (
        f"Generate a brief hypothetical legal document excerpt that would answer this question: {query}\n"
        "Write it as if it's from an actual legal case or statute. Be specific and use legal language.\n"
        "Keep it under 100 words."
    )

    try:
        response = GROQ_CLIENT.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a legal expert generating hypothetical legal text."},
                {"role": "user", "content": prompt}
            ],
            model="llama-3.1-8b-instant",
            temperature=0.3,
            max_tokens=150
        )
        return response.choices[0].message.content
    except Exception as exc:
        # HyDE is optional: degrade to "no hypothetical document", but do not
        # swallow KeyboardInterrupt/SystemExit as a bare `except:` would.
        print(f"HyDE generation failed: {exc}")
        return None
# Keywords treated as the start of a logical section.
# NOTE(review): matching is case-insensitive, so ordinary prose such as
# "holding that" or "the facts of" also counts as a section marker.
_SECTION_PATTERNS = [
    (r'(?i)\bFACTS?\b[:\s]', 'facts'),
    (r'(?i)\bHOLDING\b[:\s]', 'holding'),
    (r'(?i)\bREASONING\b[:\s]', 'reasoning'),
    (r'(?i)\bDISSENT\b[:\s]', 'dissent'),
    (r'(?i)\bCONCLUSION\b[:\s]', 'conclusion'),
]

# Importance multiplier per section type; unlisted sections use 1.0.
_SECTION_WEIGHTS = {
    'holding': 2.0, 'conclusion': 1.8, 'reasoning': 1.5,
    'facts': 1.2, 'dissent': 0.8,
}


def _build_chunk(title: str, index: int, sentences: List[str],
                 section_type: str) -> Dict[str, Any]:
    """Assemble one chunk dict (id, text, entities, importance) from sentences."""
    chunk_text = ' '.join(sentences)
    chunk_id = hashlib.md5(
        f"{title}_{index}_{chunk_text[:50]}".encode()
    ).hexdigest()[:12]

    importance = 1.0 * _SECTION_WEIGHTS.get(section_type, 1.0)
    entities = extract_legal_entities(chunk_text)
    if entities:
        # Mean entity importance boosts the chunk's score.
        entity_score = sum(e['importance'] for e in entities) / len(entities)
        importance *= (1 + entity_score * 0.5)

    return {
        'id': chunk_id,
        'text': chunk_text,
        'title': title,
        'section_type': section_type,
        'importance_score': importance,
        'entities': entities,
        'embedding': None,  # filled in by build_all_indices()
    }


def chunk_text_hierarchical(text: str, title: str = "", chunk_size: int = 500,
                            max_sentences: int = 10) -> List[Dict[str, Any]]:
    """Split ``text`` into sentence-aligned chunks tagged with a section type.

    Whitespace is collapsed, section markers are located, and sentences are
    accumulated until a chunk reaches ``chunk_size`` words or
    ``max_sentences`` sentences. A chunk is labelled with the section that
    was current at its last sentence. Every chunk, including the trailing
    partial one, gets the same section/entity importance weighting.

    Args:
        text: Raw document text.
        title: Document title stored on each chunk and mixed into its id.
        chunk_size: Word count at which a chunk is closed.
        max_sentences: Sentence count at which a chunk is closed.

    Returns:
        A list of chunk dicts; empty when ``text`` has no content.
    """
    text = re.sub(r'\s+', ' ', text).strip()
    if not text:
        return []

    # (offset, section_type) for every marker, ordered by position.
    sections = sorted(
        (
            (match.start(), section_type)
            for pattern, section_type in _SECTION_PATTERNS
            for match in re.finditer(pattern, text)
        ),
        key=lambda item: item[0],
    )

    try:
        sentences = nltk.sent_tokenize(text)
    except Exception:
        # Tokenizer data unavailable: split after sentence punctuation so
        # the terminators stay attached to their sentences.
        sentences = re.split(r'(?<=[.!?])\s+', text)
    sentences = [sent for sent in sentences if sent.strip()]

    chunks: List[Dict[str, Any]] = []
    current_section = 'introduction'
    pending: List[str] = []
    pending_words = 0
    cursor = 0          # search start, so repeated sentences map correctly
    next_section = 0    # index of the first marker not yet passed

    for sent in sentences:
        sent_pos = text.find(sent, cursor)
        if sent_pos == -1:
            # Tokenizer altered the sentence; assume it follows the last one.
            sent_pos = cursor
        else:
            cursor = sent_pos + len(sent)

        # Advance past every marker at or before this sentence's start.
        while next_section < len(sections) and sections[next_section][0] <= sent_pos:
            current_section = sections[next_section][1]
            next_section += 1

        pending.append(sent)
        pending_words += len(sent.split())

        if pending_words >= chunk_size or len(pending) >= max_sentences:
            chunks.append(_build_chunk(title, len(chunks), pending, current_section))
            pending = []
            pending_words = 0

    if pending:
        chunks.append(_build_chunk(title, len(chunks), pending, current_section))

    return chunks
def build_all_indices(chunks: List[Dict[str, Any]]):
    """Build the dense, BM25, token and entity-graph indices for ``chunks``.

    Args:
        chunks: Chunk dicts with at least ``text``, ``entities`` and
            ``importance_score``. Each chunk's ``embedding`` entry is
            overwritten with its computed vector.

    Side effects:
        Rebinds ``CHUNKS_DATA``, ``DENSE_INDEX``, ``BM25_INDEX``,
        ``TOKEN_TO_CHUNKS`` and ``CONCEPT_GRAPH``. With an empty ``chunks``
        list the indices are reset to ``None`` instead of raising.
    """
    global DENSE_INDEX, BM25_INDEX, CONCEPT_GRAPH, TOKEN_TO_CHUNKS, CHUNKS_DATA
    CHUNKS_DATA = chunks

    if not chunks:
        # np.vstack([]) and BM25 over an empty corpus both raise; clear the
        # indices so stale data from an earlier build is not queried.
        DENSE_INDEX = BM25_INDEX = CONCEPT_GRAPH = TOKEN_TO_CHUNKS = None
        print("No chunks to index; indices cleared.")
        return

    print(f"Building indices for {len(chunks)} chunks...")

    # 1. Dense embeddings + FAISS index
    print("Building FAISS index...")
    embeddings = []
    for chunk in tqdm(chunks, desc="Creating embeddings"):
        embedding = create_embedding(chunk['text'])
        chunk['embedding'] = embedding
        embeddings.append(embedding)
    embeddings_matrix = np.vstack(embeddings)
    # Inner product equals cosine similarity for the normalised vectors.
    DENSE_INDEX = faiss.IndexFlatIP(embeddings_matrix.shape[1])
    DENSE_INDEX.add(embeddings_matrix.astype('float32'))

    # 2. BM25 index for sparse retrieval (whitespace tokens, lower-cased)
    print("Building BM25 index...")
    tokenized_corpus = [chunk['text'].lower().split() for chunk in chunks]
    BM25_INDEX = BM25Okapi(tokenized_corpus)

    # 3. Token -> chunk-index inverted map, reusing the BM25 tokenization
    print("Building ColBERT token index...")
    TOKEN_TO_CHUNKS = defaultdict(set)
    for i, tokens in enumerate(tokenized_corpus):
        for token in tokens:
            TOKEN_TO_CHUNKS[token].add(i)

    # 4. Graph linking chunks that share entity strings
    print("Building legal concept graph...")
    CONCEPT_GRAPH = nx.Graph()
    # Build each chunk's entity set once rather than once per pair.
    entity_sets = [{e['text'] for e in chunk['entities']} for chunk in chunks]
    for i, chunk in enumerate(chunks):
        CONCEPT_GRAPH.add_node(i, text=chunk['text'][:200], importance=chunk['importance_score'])
        own_entities = entity_sets[i]
        if not own_entities:
            continue  # nothing to share, so no edges from this chunk
        for j in range(i + 1, len(chunks)):
            shared_entities = own_entities & entity_sets[j]
            if shared_entities:
                CONCEPT_GRAPH.add_edge(i, j, weight=len(shared_entities))

    print("All indices built successfully!")
def multi_stage_retrieval(query_analysis: Dict[str, Any], top_k: int = 10) -> List[Tuple[Dict[str, Any], float]]:
    """Retrieve and rank chunks using dense, BM25, entity and graph signals.

    Args:
        query_analysis: Output of ``analyze_query``; the keys
            ``expanded_queries``, ``original_query``, ``entities`` and
            ``query_type`` are read.
        top_k: Maximum number of results; each stage considers up to
            ``2 * top_k`` chunks.

    Returns:
        ``(chunk, score)`` pairs sorted by descending combined score. Empty
        when ``top_k`` is not positive or the indices are not built.
    """
    if top_k <= 0 or not CHUNKS_DATA or DENSE_INDEX is None or BM25_INDEX is None:
        return []

    num_chunks = len(CHUNKS_DATA)
    # Asking FAISS for more hits than exist yields padded "-1" results.
    pool_size = min(top_k * 2, num_chunks)
    candidates: Dict[str, Dict[str, Any]] = {}

    def scores_for(chunk: Dict[str, Any]) -> Dict[str, float]:
        """Return the per-method score dict for ``chunk``, creating it if new."""
        entry = candidates.setdefault(chunk['id'], {'chunk': chunk, 'scores': {}})
        return entry['scores']

    print("Performing multi-stage retrieval...")

    # Stage 1: Dense retrieval with expanded queries
    # NOTE(review): analyze_query() may supply four queries; only three are used.
    print("Stage 1: Dense retrieval...")
    for query in query_analysis['expanded_queries'][:3]:
        query_emb = create_embedding(query)
        scores, indices = DENSE_INDEX.search(
            query_emb.reshape(1, -1).astype('float32'),
            pool_size
        )
        for idx, score in zip(indices[0], scores[0]):
            idx = int(idx)
            # FAISS uses -1 for "no result"; a plain `idx < len(...)` check
            # would accept it and index the last chunk.
            if not 0 <= idx < num_chunks:
                continue
            chunk_scores = scores_for(CHUNKS_DATA[idx])
            # Keep the best similarity across the expanded queries.
            chunk_scores['dense'] = max(chunk_scores.get('dense', float('-inf')), float(score))

    # Stage 2: Sparse retrieval (BM25)
    print("Stage 2: Sparse retrieval...")
    query_tokens = query_analysis['original_query'].lower().split()
    bm25_scores = BM25_INDEX.get_scores(query_tokens)
    for idx in np.argsort(bm25_scores)[::-1][:pool_size]:
        idx = int(idx)
        score = float(bm25_scores[idx])
        # A non-positive score means no lexical match; skip it.
        if score <= 0 or idx >= num_chunks:
            continue
        scores_for(CHUNKS_DATA[idx])['bm25'] = score

    # Stage 3: Entity-based retrieval
    print("Stage 3: Entity-based retrieval...")
    query_entities = query_analysis['entities']
    if query_entities:
        # Lower-cased entity sets, built once and reused for every entity.
        chunk_entity_sets = [
            {e['text'].lower() for e in chunk['entities']} for chunk in CHUNKS_DATA
        ]
        for entity in query_entities:
            needle = entity['text'].lower()
            for chunk, names in zip(CHUNKS_DATA, chunk_entity_sets):
                if needle in names:
                    chunk_scores = scores_for(chunk)
                    chunk_scores['entity'] = chunk_scores.get('entity', 0) + entity['importance']

    # Stage 4: Graph-based retrieval
    print("Stage 4: Graph-based retrieval...")
    if candidates and CONCEPT_GRAPH is not None:
        # First position of each id, matching a first-match linear scan.
        index_by_id: Dict[str, int] = {}
        for i, chunk in enumerate(CHUNKS_DATA):
            index_by_id.setdefault(chunk['id'], i)

        # Seeds are the first five candidates in insertion order, snapshotted
        # before neighbours are added.
        seed_chunks = [index_by_id[cid] for cid in list(candidates)[:5] if cid in index_by_id]
        for seed_idx in seed_chunks:
            if seed_idx not in CONCEPT_GRAPH:
                continue
            for neighbor_idx in list(CONCEPT_GRAPH.neighbors(seed_idx))[:3]:
                if 0 <= neighbor_idx < num_chunks:
                    scores_for(CHUNKS_DATA[neighbor_idx])['graph'] = 0.5

    # Combine scores
    print("Combining scores...")
    weights = {'dense': 0.35, 'bm25': 0.25, 'entity': 0.25, 'graph': 0.15}
    final_scores = []
    for data in candidates.values():
        chunk = data['chunk']
        scores = data['scores']
        final_score = 0
        for method, weight in weights.items():
            if method not in scores:
                continue
            # Bring each signal onto a roughly [0, 1] scale before weighting.
            if method == 'dense':
                normalized = (scores[method] + 1) / 2  # [-1, 1] to [0, 1]
            elif method == 'bm25':
                normalized = min(scores[method] / 10, 1)
            elif method == 'entity':
                normalized = min(scores[method] / 3, 1)
            else:
                normalized = scores[method]
            final_score += weight * normalized

        # Boost by importance and section relevance
        final_score *= chunk['importance_score']
        if query_analysis['query_type'] == 'precedent' and chunk['section_type'] == 'holding':
            final_score *= 1.5
        elif query_analysis['query_type'] == 'factual' and chunk['section_type'] == 'facts':
            final_score *= 1.5
        final_scores.append((chunk, final_score))

    # Sort and return top-k
    final_scores.sort(key=lambda x: x[1], reverse=True)
    return final_scores[:top_k]
def generate_answer_with_reasoning(query: str, retrieved_chunks: List[Tuple[Dict[str, Any], float]]) -> Dict[str, Any]:
    """Generate an IRAC-style answer from retrieved chunks via the Groq model.

    Args:
        query: The user's question.
        retrieved_chunks: ``(chunk, score)`` pairs. Each chunk needs ``id``
            and ``text``; ``title``, ``section_type`` and ``entities`` are
            optional and default to ``''``, ``'unknown'`` and ``[]``.

    Returns:
        On success, a dict with ``answer``, ``confidence`` (0-100, from the
        mean score of the top three chunks) and ``sources``. Otherwise a
        dict with ``error`` (and ``sources`` when chunks were involved).
    """
    if not GROQ_CLIENT:
        return {'error': 'Groq client not initialized'}
    if not retrieved_chunks:
        # Nothing to ground an answer on: skip the API call, which would
        # otherwise be followed by a division by zero in the confidence.
        return {'error': 'No retrieved documents to analyze', 'sources': []}

    # Prepare context
    context_parts = []
    for i, (chunk, score) in enumerate(retrieved_chunks, 1):
        entities = ', '.join([e['text'] for e in chunk.get('entities', [])[:3]])
        context_parts.append(
            f"\nDocument {i} [{chunk.get('title', '')}] - Relevance: {score:.2f}\n"
            f"Section: {chunk.get('section_type', 'unknown')}\n"
            f"Key Entities: {entities}\n"
            f"Content: {chunk['text'][:800]}\n"
        )
    context = "\n---\n".join(context_parts)

    system_prompt = (
        "You are an expert legal analyst. Provide thorough legal analysis using the IRAC method:\n"
        "1. ISSUE: Identify the legal issue(s)\n"
        "2. RULE: State the applicable legal rules/precedents\n"
        "3. APPLICATION: Apply the rules to the facts\n"
        "4. CONCLUSION: Provide a clear conclusion\n"
        "CRITICAL: Base ALL responses on the provided document excerpts only. Quote directly when making claims.\n"
        "If information is not in the excerpts, state \"This information is not provided in the available documents.\"\n"
    )
    user_prompt = (
        f"Query: {query}\n"
        "Retrieved Legal Documents:\n"
        f"{context}\n"
        "Please provide a comprehensive legal analysis using IRAC method. Cite the documents when making claims."
    )

    try:
        response = GROQ_CLIENT.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            model="llama-3.1-8b-instant",
            temperature=0.1,
            max_tokens=1000
        )
        answer = response.choices[0].message.content
    except Exception as e:
        return {
            'error': f'Error generating answer: {str(e)}',
            'sources': [{'chunk': c['text'][:200], 'score': s} for c, s in retrieved_chunks[:3]]
        }

    # Confidence: mean score of the top three chunks as a percentage. Scores
    # may exceed 1 after importance boosts, so clamp to 0-100.
    top_scores = [score for _, score in retrieved_chunks[:3]]
    avg_score = sum(top_scores) / len(top_scores)
    confidence = max(0.0, min(avg_score * 100, 100))

    return {
        'answer': answer,
        'confidence': confidence,
        'sources': [
            {
                'chunk_id': chunk['id'],
                'title': chunk.get('title', ''),
                'section': chunk.get('section_type', 'unknown'),
                'relevance_score': float(score),
                'excerpt': chunk['text'][:200] + '...',
                'entities': [e['text'] for e in chunk.get('entities', [])[:5]]
            }
            for chunk, score in retrieved_chunks
        ]
    }
| # Main functions for external use | |
def process_documents(documents: List[Dict[str, str]]) -> Dict[str, Any]:
    """Chunk every document, index the chunks, and report the totals.

    Each document dict must have ``text``; ``title`` falls back to
    ``'Document'``. Returns a summary dict with ``success``,
    ``chunk_count`` and a human-readable ``message``.
    """
    collected = [
        piece
        for document in documents
        for piece in chunk_text_hierarchical(
            document['text'], document.get('title', 'Document')
        )
    ]

    build_all_indices(collected)

    total_chunks = len(collected)
    summary = {'success': True, 'chunk_count': total_chunks}
    summary['message'] = f'Processed {len(documents)} documents into {total_chunks} chunks'
    return summary
def query_documents(query: str, top_k: int = 5) -> Dict[str, Any]:
    """Answer ``query`` from the indexed documents.

    Runs query analysis, multi-stage retrieval and answer generation in
    turn. Returns the generation result with ``query_analysis`` attached,
    or an ``error`` dict when nothing is indexed or nothing is retrieved.
    """
    if not CHUNKS_DATA:
        return {'error': 'No documents indexed. Call process_documents first.'}

    analysis = analyze_query(query)
    hits = multi_stage_retrieval(analysis, top_k)

    if hits:
        outcome = generate_answer_with_reasoning(query, hits)
        outcome['query_analysis'] = analysis
        return outcome

    return {
        'error': 'No relevant documents found',
        'query_analysis': analysis,
    }
def search_chunks_simple(query: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """Return ranked chunks for ``query`` in a reduced compatibility shape.

    Each result is ``{'chunk': {'id', 'text', 'title'}, 'score': float}``.
    An empty list is returned when no documents have been indexed.
    """
    if not CHUNKS_DATA:
        return []

    ranked = multi_stage_retrieval(analyze_query(query), top_k)

    return [
        {
            'chunk': {
                'id': hit['id'],
                'text': hit['text'],
                'title': hit['title'],
            },
            'score': relevance,
        }
        for hit, relevance in ranked
    ]
def generate_conservative_answer(query: str, context_chunks: List[Dict[str, Any]]) -> str:
    """Generate an answer string from ``{'chunk': ..., 'score': ...}`` items.

    Compatibility wrapper around ``generate_answer_with_reasoning``.

    Args:
        query: The user's question.
        context_chunks: Items shaped like the results of
            ``search_chunks_simple``.

    Returns:
        The generated answer, the error message when generation fails, or a
        fixed message when ``context_chunks`` is empty.
    """
    if not context_chunks:
        return "No relevant information found."

    # search_chunks_simple() returns chunks with only id/text/title, while
    # answer generation also reads section_type and entities; fill neutral
    # defaults so those reduced chunks do not raise KeyError.
    retrieved_chunks = [
        (
            {'title': '', 'section_type': 'unknown', 'entities': [], **item['chunk']},
            item['score'],
        )
        for item in context_chunks
    ]

    result = generate_answer_with_reasoning(query, retrieved_chunks)
    if 'error' in result:
        return result['error']
    return result.get('answer', 'Unable to generate answer.')