Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import numpy as np | |
| import json | |
| import faiss | |
| import re | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| from groq import Groq | |
| import os | |
| from typing import List, Dict, Optional | |
| import logging | |
| import httpx | |
# Root logging configuration (INFO and above) plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application. Swagger UI is mounted at "/" and ReDoc at "/redoc".
app = FastAPI(
    docs_url="/",
    redoc_url="/redoc",
    title="LexNepal AI API",
    version="1.0.0",
    description="Advanced Legal Intelligence API for Nepal Legal Code",
)
# Cross-origin access: every origin, HTTP method and request header is allowed.
# NOTE(review): the FastAPI CORS docs advise against combining wildcard origins
# with allow_credentials=True. No handler in this file reads cookies or auth
# headers, so allow_credentials=False is probably safe. Confirm that no frontend
# sends credentialed requests (fetch credentials:'include', axios
# withCredentials) before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Request body for a legal question.
class QueryRequest(BaseModel):
    # The natural-language question; process_legal_query rejects blank values
    # and values longer than 1000 characters.
    query: str
    # Maximum number of provisions to return; forwarded to get_premium_context.
    # Typed Optional, so clients may also send null.
    max_sources: Optional[int] = 10
# One retrieved legal provision, as returned to the client.
class Source(BaseModel):
    law: str  # law name taken from the metadata record
    section: str  # section number; process_legal_query converts it with str()
    section_title: str
    text: str  # provision text
    rel_score: float  # cross-encoder relevance score; results are sorted high-to-low
# Response body for a legal question.
class QueryResponse(BaseModel):
    # LLM-generated answer, or a fixed message when no provision was retrieved.
    answer: str
    # Provisions used as LLM context, best-scoring first.
    sources: List[Source]
    # Echo of the submitted question.
    query: str
    # Number of provisions returned (process_legal_query sets it to len(sources)).
    total_candidates: int
# Shape of the statistics payload built by get_statistics.
class StatsResponse(BaseModel):
    total_provisions: int  # number of metadata records
    total_laws: int  # number of distinct "law" values in the metadata
    vector_dimensions: int  # hard-coded in get_statistics, not read from the index
    embedding_model: str
    reranking_model: str
    llm_model: str
# Shape of the health-check payload built by health_check.
class HealthResponse(BaseModel):
    status: str  # "healthy" or "unhealthy"
    # NOTE(review): health_check sets this from get_metadata() success only;
    # the encoders and the FAISS index are not inspected.
    models_loaded: bool
    message: Optional[str] = None  # human-readable status or error text
# Lazily initialised singletons. Each starts as None and is filled in on first
# use by its get_* loader, then reused for the rest of the process.
_bi_encoder = _cross_encoder = _groq_client = _index = _metadata = None
def get_bi_encoder():
    """Return the shared sentence bi-encoder, loading it on the first call.

    The "all-mpnet-base-v2" SentenceTransformer is cached in the module-level
    ``_bi_encoder`` so every later call returns the same instance.
    """
    global _bi_encoder
    if _bi_encoder is not None:
        return _bi_encoder

    logger.info("Loading bi-encoder (MPNet)...")
    _bi_encoder = SentenceTransformer("all-mpnet-base-v2")
    logger.info("✅ Bi-encoder loaded successfully")
    return _bi_encoder
def get_cross_encoder():
    """Return the shared reranking cross-encoder, loading it on the first call.

    The "cross-encoder/ms-marco-MiniLM-L-6-v2" model is cached in the
    module-level ``_cross_encoder`` and reused afterwards.
    """
    global _cross_encoder
    if _cross_encoder is not None:
        return _cross_encoder

    logger.info("Loading cross-encoder...")
    model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    _cross_encoder = CrossEncoder(model_name)
    logger.info("✅ Cross-encoder loaded successfully")
    return _cross_encoder
def get_groq_client():
    """Return the shared Groq client, creating it on the first call.

    The API key is read only from the ``GROQ_API_KEY`` environment variable;
    there is no fallback key.

    Raises:
        HTTPException: 503 if the variable is unset/empty or the client
            constructor fails.
    """
    global _groq_client
    if _groq_client is not None:
        return _groq_client

    logger.info("Initializing Groq client...")
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        logger.error("❌ GROQ_API_KEY not found in environment")
        raise HTTPException(
            status_code=503,
            detail="GROQ_API_KEY not configured. Please set it in Hugging Face Space secrets.",
        )

    try:
        client = Groq(api_key=api_key)
    except Exception as exc:
        logger.error(f"❌ Failed to initialize Groq client: {exc}")
        raise HTTPException(
            status_code=503,
            detail=f"Failed to initialize Groq client: {str(exc)}",
        )

    # Only cache the client once construction has succeeded.
    _groq_client = client
    logger.info("✅ Groq client initialized successfully")
    return _groq_client
def get_index():
    """Return the shared FAISS index, building it from the embeddings file on first use.

    The index is an exact (flat) L2 index over the rows of
    ``final_legal_embeddings.npy``. get_premium_context looks up search hits
    in the metadata list by position, so row ``i`` must describe metadata
    entry ``i``.

    Raises:
        HTTPException: 503 if the embeddings file is missing or does not hold
            a 2-D (vectors, dimensions) array.
    """
    global _index
    if _index is None:
        logger.info("Loading embeddings and creating FAISS index...")
        try:
            embeddings = np.load("final_legal_embeddings.npy")
        except FileNotFoundError:
            logger.error("❌ Embeddings file not found")
            raise HTTPException(
                status_code=503,
                detail="Embeddings file not found. Please upload final_legal_embeddings.npy"
            )
        logger.info(f"Embeddings shape: {embeddings.shape}")

        # A 1-D/3-D array would otherwise fail with an opaque IndexError or
        # FAISS assertion further down.
        if embeddings.ndim != 2:
            logger.error(f"❌ Unexpected embeddings shape: {embeddings.shape}")
            raise HTTPException(
                status_code=503,
                detail=f"Embeddings file has shape {embeddings.shape}; expected a 2-D (vectors, dimensions) array"
            )

        # FAISS works on C-contiguous float32 matrices.
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)

        # Build the index completely before publishing it to the global. If
        # add() raised after assigning _index, an empty index would stay cached
        # and every later search would silently return no hits.
        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)
        _index = index
        logger.info(f"✅ FAISS index created with {vectors.shape[0]} vectors")
    return _index
def get_metadata():
    """Return the list of legal-provision records, reading it from disk on first use.

    Records are parsed from ``final_legal_laws_metadata.json`` (UTF-8) and
    cached in the module-level ``_metadata``.

    Raises:
        HTTPException: 503 if the metadata file does not exist.
    """
    global _metadata
    if _metadata is not None:
        return _metadata

    logger.info("Loading metadata...")
    try:
        with open("final_legal_laws_metadata.json", "r", encoding="utf-8") as handle:
            _metadata = json.load(handle)
    except FileNotFoundError:
        logger.error("❌ Metadata file not found")
        raise HTTPException(
            status_code=503,
            detail="Metadata file not found. Please upload final_legal_laws_metadata.json",
        )
    else:
        logger.info(f"✅ Loaded {len(_metadata)} legal provisions")
    return _metadata
def get_premium_context(query: str, max_sources: Optional[int] = 10) -> List[Dict]:
    """Retrieve and rerank the legal provisions most relevant to ``query``.

    Pipeline:
      1. Encode the query with the bi-encoder.
      2. Dense retrieval: the 25 nearest vectors in the FAISS index.
      3. Keyword boosting: add every provision whose ``section`` equals a
         number written in the query (e.g. "Section 12" pulls in section 12
         of every law).
      4. Score all candidates with the cross-encoder and keep the best
         ``max_sources``.

    Args:
        query: The user's question.
        max_sources: Maximum number of provisions to return. ``None`` falls
            back to 10; values below 1 are raised to 1.

    Returns:
        Copies of metadata records, each with an added float ``rel_score``,
        sorted by descending score. Empty when nothing was retrieved.

    Raises:
        HTTPException: errors raised by the loaders (e.g. 503 for a missing
            data file) propagate unchanged; any other failure becomes a 500.
    """
    # Without this, None would slice to "everything" and 0/negatives would
    # empty or truncate the ranked list.
    limit = 10 if max_sources is None else max(1, int(max_sources))
    try:
        bi_encoder = get_bi_encoder()
        cross_encoder = get_cross_encoder()
        index = get_index()
        metadata = get_metadata()

        # Stage 1: encode the query as a float32 row for FAISS.
        query_embedding = bi_encoder.encode([query], convert_to_numpy=True).astype('float32')

        # Stage 2: dense retrieval. -1 marks "no result" and ids outside the
        # metadata list are ignored.
        _, indices = index.search(query_embedding, 25)
        candidates: List[Dict] = []
        seen = set()
        for raw_id in indices[0]:
            i = int(raw_id)
            if 0 <= i < len(metadata) and i not in seen:
                candidates.append(metadata[i].copy())
                seen.add(i)

        # Stage 3: keyword boosting on section numbers mentioned in the query.
        # A set gives O(1) membership per record instead of rescanning the list.
        numbers = set(re.findall(r'\d+', query))
        if numbers:
            for i, item in enumerate(metadata):
                if i not in seen and str(item.get('section', '')) in numbers:
                    candidates.append(item.copy())
                    seen.add(i)

        # Stage 4: cross-encoder reranking, then keep the top `limit`.
        if candidates:
            pairs = [
                [query, f"{c.get('law', '')} {c.get('section_title', '')} {c.get('text', '')}"]
                for c in candidates
            ]
            scores = cross_encoder.predict(pairs)
            for c, score in zip(candidates, scores):
                c['rel_score'] = float(score)
            candidates.sort(key=lambda c: c['rel_score'], reverse=True)
            candidates = candidates[:limit]

        logger.info(f"Retrieved {len(candidates)} relevant candidates")
        return candidates
    except HTTPException:
        # Keep the loader's own status/detail (e.g. 503) instead of
        # re-wrapping it as a generic 500.
        raise
    except Exception as e:
        logger.exception(f"Error in context retrieval: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Context retrieval error: {str(e)}") from e
async def health_check():
    """Report whether the legal-provision metadata could be loaded."""
    # NOTE(review): "models_loaded" only reflects whether get_metadata()
    # succeeded; the encoders, FAISS index and Groq client are not checked.
    # Any Exception (including an HTTPException from the loader) is reported
    # in the payload instead of being raised.
    try:
        provisions = get_metadata()
        summary = f"API is healthy. {len(provisions)} provisions loaded."
    except Exception as exc:
        return {
            "status": "unhealthy",
            "models_loaded": False,
            "message": f"Error: {str(exc)}",
        }
    return {
        "status": "healthy",
        "models_loaded": True,
        "message": summary,
    }
async def get_statistics():
    """Get database statistics: provision and law counts plus the models in use."""
    # total_provisions / total_laws are computed from the metadata file. The
    # other values are hard-coded labels (not read from the loaded models or
    # the FAISS index), so they must be kept in sync with the loaders and the
    # Groq call by hand.
    try:
        metadata = get_metadata()
        unique_laws = len({d.get('law', '') for d in metadata})
        return {
            "total_provisions": len(metadata),
            "total_laws": unique_laws,
            "vector_dimensions": 768,
            "embedding_model": "all-mpnet-base-v2",
            "reranking_model": "ms-marco-MiniLM-L-6-v2",
            "llm_model": "llama-3.3-70b-versatile"
        }
    except HTTPException:
        # Already carries the correct status and detail (e.g. 503 from
        # get_metadata); re-wrapping would replace its detail with str() of
        # the exception object.
        raise
    except Exception as e:
        logger.exception(f"Error getting stats: {str(e)}")
        raise HTTPException(status_code=503, detail=str(e)) from e
async def process_legal_query(request: QueryRequest):
    """Process legal query with RAG pipeline.

    Retrieves and reranks relevant provisions of Nepal law, then generates an
    answer grounded in them. Returns the answer together with its sources.
    """
    # Starlette's helper (re-exported by FastAPI) for running blocking calls
    # in a worker thread; imported locally to keep this handler self-contained.
    from fastapi.concurrency import run_in_threadpool

    # --- Validation (400 on bad input) ---
    if not request.query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    if len(request.query) > 1000:
        raise HTTPException(status_code=400, detail="Query too long (max 1000 characters)")
    # max_sources is Optional[int]: null means "use the default", whereas 0 or
    # a negative value would empty or truncate the ranked result list.
    if request.max_sources is not None and request.max_sources < 1:
        raise HTTPException(status_code=400, detail="max_sources must be at least 1")
    max_sources = 10 if request.max_sources is None else request.max_sources

    try:
        logger.info(f"Processing query: {request.query[:100]}...")

        # Retrieval runs model inference synchronously. Calling it directly in
        # this coroutine would block the event loop, and with it every other
        # request (including the health check), until it finished.
        # NOTE: the lazy get_* loaders take no lock, so concurrent first
        # requests may each load a model once; the last assignment wins.
        candidates = await run_in_threadpool(get_premium_context, request.query, max_sources)

        if not candidates:
            return {
                "answer": "No relevant legal provisions found in the database for your query. Please try rephrasing or consult a legal professional.",
                "sources": [],
                "query": request.query,
                "total_candidates": 0
            }

        # Context handed to the LLM, one provision per paragraph. .get() keeps
        # one incomplete record from failing the whole request, matching how
        # get_premium_context reads the same fields.
        context_str = "\n\n".join(
            f"[{d.get('law', '')} Section {d.get('section', '')}]: {d.get('text', '')}"
            for d in candidates
        )

        system_prompt = (
            "You are an Elite Legal Advisor specializing in Nepal law.\n"
            "OPERATIONAL MANDATE:\n"
            "1. Answer STRICTLY from provided legal text\n"
            '2. If information is absent, state: "No specific provision found in current database"\n'
            "3. Always cite exact Law name and Section number\n"
            "4. Use formal, authoritative legal language\n"
            "5. NEVER hallucinate or infer beyond provided text\n"
            "6. Maintain zero-tolerance policy for speculation\n"
            'When citing, use format: "According to [Law Name], Section [Number]..."\n'
            "Provide clear, structured answers with proper legal citations."
        )

        logger.info("Generating LLM response...")
        groq_client = get_groq_client()
        # The Groq SDK call is a synchronous HTTP request; keep it off the
        # event loop as well.
        response = await run_in_threadpool(
            groq_client.chat.completions.create,
            model="llama-3.3-70b-versatile",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Legal Context:\n{context_str}\n\nQuery: {request.query}"}
            ],
            temperature=0,
            max_tokens=1500
        )
        # Guard against missing content so the answer is always a string.
        answer = response.choices[0].message.content or ""

        sources = [
            Source(
                law=d.get('law', ''),
                section=str(d.get('section', '')),
                section_title=d.get('section_title', ''),
                text=d.get('text', ''),
                rel_score=d['rel_score']
            )
            for d in candidates
        ]

        logger.info(f"✅ Query processed successfully with {len(sources)} sources")
        return {
            "answer": answer,
            "sources": sources,
            "query": request.query,
            "total_candidates": len(candidates)
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error processing query: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Query processing error: {str(e)}") from e
async def root():
    """Return basic API information and the list of available endpoints."""
    # NOTE(review): the app is created with docs_url="/". If this handler is
    # also mounted at "/", the two routes collide and only one is reachable.
    # Confirm where it is registered.
    endpoint_map = {
        "docs": "/ (Swagger UI)",
        "health": "/health (GET)",
        "stats": "/stats (GET)",
        "query": "/query (POST)",
    }
    info = dict(
        message="🇳🇵 LexNepal AI API is running",
        version="1.0.0",
        description="Advanced Legal Intelligence for Nepal Legal Code",
        endpoints=endpoint_map,
    )
    info["technology"] = "RAG with Hybrid Retrieval + Cross-Encoder Reranking"
    info["support"] = "https://huggingface.co/spaces/yamraj047/lexnepal-api"
    return info
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")