"""
LexiBot FastAPI Backend
Headless RAG API for Indian Legal Information
Replaces the legacy Telegram bot (main.py) with a REST API.
Designed for deployment on Hugging Face Spaces.
"""
import os
from typing import Dict
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_pinecone import PineconeEmbeddings, PineconeVectorStore
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
load_dotenv()
# Configuration (all read from the environment / .env file)
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")  # Gemini LLM access
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")  # Pinecone embeddings + index access
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")  # name of the pre-ingested legal index
# Global state, populated once by the lifespan handler at startup
vector_store = None
llm = None
embeddings = None
# Per-session conversation memory; RAM-only, cleared on restart (see get_session_memory)
session_memories: Dict[str, ConversationBufferWindowMemory] = {}
# Pydantic Models
class ChatRequest(BaseModel):
    """Request body for POST /chat."""
    # The user's legal question.
    message: str
    # Client-chosen key used to look up per-conversation memory.
    session_id: str
class ChatResponse(BaseModel):
    """Response body for POST /chat."""
    # The AI-generated answer text.
    response: str
    # Unique act names / source identifiers from the retrieved documents.
    sources: list[str]
class HealthResponse(BaseModel):
    """Body returned by the health-check endpoints ("/" and "/health")."""
    # Always "ok" when the service is up.
    status: str
# Legal RAG Prompt: grounds the LLM strictly in retrieved context and forces
# an explicit "don't know" answer when the context is insufficient.
LEGAL_PROMPT = PromptTemplate(
    template="""You are LexiBot, an AI legal assistant specializing in Indian law.
IMPORTANT GUIDELINES:
- Provide accurate information based ONLY on the context provided
- If the context doesn't contain relevant information, say "I don't have specific information about that in my legal database"
- Always recommend consulting a qualified lawyer for specific legal matters
- Be clear, concise, and use simple language
- When citing laws, mention the specific Act name and section number
CONTEXT FROM LEGAL DATABASE:
{context}
USER QUESTION: {question}
Provide a helpful, accurate response based on the legal context above:""",
    input_variables=["context", "question"]  # filled in by the "stuff" QA chain in /chat
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize shared RAG resources on startup; log on shutdown.

    Populates the module-level ``embeddings``, ``llm`` and ``vector_store``
    globals so request handlers reuse a single set of clients.

    Raises:
        ValueError: If any required environment variable is missing.
    """
    global vector_store, llm, embeddings
    print("🚀 Initializing LexiBot API...")
    # Validate environment up front so misconfiguration fails fast at startup
    # instead of surfacing as an opaque error on the first request.
    if not GOOGLE_API_KEY:
        raise ValueError("GOOGLE_API_KEY not set")
    if not PINECONE_API_KEY:
        raise ValueError("PINECONE_API_KEY not set")
    if not PINECONE_INDEX_NAME:
        # Fix: previously unchecked — a missing index name only failed later,
        # inside the Pinecone connection, with a far less obvious error.
        raise ValueError("PINECONE_INDEX_NAME not set")
    # Initialize Pinecone embeddings (must match the model used by the ingestion script)
    print(" Loading Pinecone Embeddings (multilingual-e5-large)...")
    embeddings = PineconeEmbeddings(
        model="multilingual-e5-large",
        pinecone_api_key=PINECONE_API_KEY
    )
    # Initialize LLM
    print(" Loading Gemini LLM...")
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        google_api_key=GOOGLE_API_KEY,
        temperature=0.1,  # low temperature: favor factual, grounded answers
        max_tokens=2048
    )
    # Connect to the existing Pinecone index
    print(" Connecting to Pinecone...")
    vector_store = PineconeVectorStore(
        index_name=PINECONE_INDEX_NAME,
        embedding=embeddings,
        pinecone_api_key=PINECONE_API_KEY
    )
    print("✅ LexiBot API Ready!")
    yield
    # Cleanup (the clients hold no resources needing explicit release)
    print("👋 Shutting down LexiBot API...")
# Initialize FastAPI application (startup wiring happens in `lifespan` above)
app = FastAPI(
    title="LexiBot API",
    description="Headless RAG API for Indian Legal Information",
    version="2.0.0",
    lifespan=lifespan  # initializes the shared embeddings/LLM/vector-store globals
)
# CORS Middleware - Allow all origins for frontend integration
# NOTE(review): per the CORS spec, browsers reject credentialed responses when
# Access-Control-Allow-Origin is "*" — if cookies/auth headers are ever needed,
# replace the wildcard with an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_session_memory(session_id: str) -> ConversationBufferWindowMemory:
    """
    Return the conversation memory for *session_id*, creating it on first use.

    Memory lives only in process RAM: a container restart wipes every session.
    That trade-off is acceptable for the MVP as per PRD requirements.
    """
    memory = session_memories.get(session_id)
    if memory is None:
        # A window of 5 exchanges keeps prompts small while retaining recent context.
        memory = ConversationBufferWindowMemory(
            k=5,
            memory_key="chat_history",
            return_messages=True
        )
        session_memories[session_id] = memory
    return memory
def extract_sources(source_documents) -> list[str]:
    """Collect the unique source labels cited by *source_documents*.

    Prefers each document's ``act_name`` metadata key, falling back to
    ``source``; documents carrying neither key are skipped. Order of the
    returned list is unspecified (set-derived).
    """
    unique_labels = {
        meta["act_name"] if "act_name" in meta else meta["source"]
        for meta in (document.metadata for document in source_documents)
        if "act_name" in meta or "source" in meta
    }
    return list(unique_labels)
@app.get("/", response_model=HealthResponse)
async def root():
    """Liveness probe at the API root (same contract as /health)."""
    payload = HealthResponse(status="ok")
    return payload
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Dedicated health-check endpoint for uptime monitoring."""
    payload = HealthResponse(status="ok")
    return payload
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """
    Main chat endpoint for legal queries.
    - **message**: The user's legal question
    - **session_id**: Unique session identifier for conversation memory
    Returns:
    - **response**: The AI-generated legal response
    - **sources**: List of legal acts referenced in the response
    """
    # Globals are populated by the lifespan handler; 503 until startup completes.
    if not vector_store or not llm:
        raise HTTPException(status_code=503, detail="Service not initialized")
    if not request.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    try:
        # Get session memory
        # NOTE(review): this memory is written below but never wired into the
        # QA chain (LEGAL_PROMPT has no chat_history variable), so previous
        # turns cannot influence answers — confirm this is intentional.
        memory = get_session_memory(request.session_id)
        # Create QA chain (rebuilt per request from the shared global components)
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",  # concatenate all retrieved chunks into one prompt
            retriever=vector_store.as_retriever(search_kwargs={"k": 5}),  # top-5 chunks
            return_source_documents=True,
            chain_type_kwargs={"prompt": LEGAL_PROMPT}
        )
        # Execute query
        result = qa_chain.invoke({"query": request.message})
        # Extract response and sources
        response_text = result.get("result", "I couldn't process your query.")
        source_docs = result.get("source_documents", [])
        sources = extract_sources(source_docs)
        # Save to memory
        memory.save_context(
            {"input": request.message},
            {"output": response_text}
        )
        return ChatResponse(
            response=response_text,
            sources=sources
        )
    except Exception as e:
        # Broad catch at the endpoint boundary. NOTE(review): detail exposes
        # the internal error string to clients — consider a generic message
        # for production.
        print(f"Error processing chat: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
if __name__ == "__main__":
    # Local development entry point; port 7860 is presumably chosen to match
    # Hugging Face Spaces' expected port (see module docstring) — confirm.
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)