Files changed (1) hide show
  1. app.py +261 -261
app.py CHANGED
@@ -1,261 +1,261 @@
1
- import os
2
- import uuid
3
- import logging
4
- from datetime import datetime, timedelta
5
- from contextlib import asynccontextmanager
6
-
7
- from fastapi import FastAPI, HTTPException, Request, Depends, Response, Cookie
8
- from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
9
- from fastapi.middleware.cors import CORSMiddleware
10
- from pydantic import BaseModel
11
- from pydantic_settings import BaseSettings
12
- from dotenv import load_dotenv
13
-
14
- from upstash_redis.asyncio import Redis
15
- from slowapi import Limiter
16
- from slowapi.errors import RateLimitExceeded
17
- from slowapi.util import get_remote_address
18
- from slowapi.middleware import SlowAPIMiddleware
19
-
20
- from openai import OpenAI
21
- from langchain_community.embeddings import OpenAIEmbeddings
22
- from langchain_community.vectorstores import Chroma
23
- from langchain_community.chat_models import ChatOpenAI
24
- from langchain.chains import LLMChain
25
- from langchain.prompts import PromptTemplate
26
-
27
- # ─── SETTINGS ────────────────────────────────────────────────────────────────────
28
- class Settings(BaseSettings):
29
- OPENAI_API_KEY: str
30
- UPSTASH_REDIS_REST_URL: str
31
- UPSTASH_REDIS_REST_TOKEN: str
32
- VECTOR_DB_PATH: str = "./chroma_db"
33
- TOP_K: int = 5
34
- SESSION_TIMEOUT_MIN: int = 30
35
- RATE_LIMIT: str = "60/minute"
36
-
37
- class Config:
38
- env_file = ".env"
39
- extra = "ignore" # Add this line to ignore extra variables
40
-
41
- settings = Settings()
42
- load_dotenv()
43
-
44
- # ─── LOGGING ─────────────────────────────────────────────────────────────────────
45
- logging.basicConfig(
46
- level=logging.INFO,
47
- format='%(asctime)s %(levelname)s %(name)s %(message)s'
48
- )
49
- logger = logging.getLogger("legal-bot")
50
-
51
- # ─── LIFESPAN MANAGEMENT ─────────────────────────────────────────────────────────
52
- @asynccontextmanager
53
- async def lifespan(app: FastAPI):
54
- global redis
55
- redis = Redis(
56
- url=settings.UPSTASH_REDIS_REST_URL,
57
- token=settings.UPSTASH_REDIS_REST_TOKEN
58
- )
59
- logger.info("Upstash Redis connection established")
60
- yield
61
- await redis.close()
62
- logger.info("Upstash Redis connection closed")
63
-
64
- # ─── FASTAPI APP ────────────────────────────────────────────────────────────────
65
- app = FastAPI(
66
- title="Irish Legal AI Bot",
67
- description="RAG‑driven Irish legal assistant",
68
- lifespan=lifespan
69
- )
70
-
71
- # CORS
72
- app.add_middleware(
73
- CORSMiddleware,
74
- allow_origins=["http://localhost:8000"],
75
- allow_methods=["GET", "POST"],
76
- allow_headers=["*"],
77
- allow_credentials=True,
78
- )
79
-
80
- # Rate limiting
81
- limiter = Limiter(key_func=get_remote_address)
82
- app.state.limiter = limiter
83
- app.add_middleware(SlowAPIMiddleware)
84
-
85
- # ─── SECURITY & MODERATION ───────────────────────────────────────────────────────
86
- openai_client = OpenAI(api_key=settings.OPENAI_API_KEY)
87
-
88
- async def moderate_content(text: str) -> bool:
89
- try:
90
- resp = openai_client.moderations.create(input=text)
91
- return not resp.results[0].flagged
92
- except Exception as e:
93
- logger.error(f"Moderation error: {e}")
94
- return False
95
-
96
- # ─── SESSION MANAGEMENT ──────────────────────────────────────────────────────────
97
- class SessionData(BaseModel):
98
- session_id: str
99
- created_at: datetime
100
- last_activity: datetime
101
- history: list
102
-
103
- async def get_session(session_id: str = Cookie(default=None), response: Response = None) -> SessionData:
104
- if session_id:
105
- raw = await redis.get(session_id)
106
- if raw:
107
- data = SessionData.parse_raw(raw)
108
- # Update last activity
109
- data.last_activity = datetime.utcnow()
110
- await redis.setex(
111
- session_id,
112
- settings.SESSION_TIMEOUT_MIN * 60,
113
- data.json()
114
- )
115
- return data
116
-
117
- # Create new session
118
- new_id = str(uuid.uuid4())
119
- data = SessionData(
120
- session_id=new_id,
121
- created_at=datetime.utcnow(),
122
- last_activity=datetime.utcnow(),
123
- history=[]
124
- )
125
- await redis.setex(
126
- new_id,
127
- settings.SESSION_TIMEOUT_MIN * 60,
128
- data.json()
129
- )
130
- response.set_cookie(key="session_id", value=new_id, httponly=True, secure=True)
131
- return data
132
-
133
- # ─── VECTOR & LLM SETUP ─────────────────────────────────────────────────────────
134
- embeddings = OpenAIEmbeddings(openai_api_key=settings.OPENAI_API_KEY)
135
- vectordb = Chroma(embedding_function=embeddings, persist_directory=settings.VECTOR_DB_PATH)
136
- LEGAL_PROMPT = PromptTemplate(
137
- input_variables=["context","question","history"],
138
- template=(
139
- "As an Irish legal expert, provide a precise, concise answer using ONLY the context below."
140
- "\n1. Direct answer (1-2 sentences)\n2. Key legal basis (cite sources)\n3. Practical implications"
141
- "\n\nContext:\n{context}\n\nHistory:\n{history}\n\nQuestion: {question}\n\nAnswer:" )
142
- )
143
- POLISH_PROMPT = PromptTemplate(
144
- input_variables=["raw_answer","question"],
145
- template=(
146
- "Enhance this Irish legal answer with current figures/fines (2024), recent amendments, and practical next steps."
147
- " Keep response under 150 words.\n\nOriginal:\n{raw_answer}\n\nQuestion: {question}\n\nEnhanced Answer:" )
148
- )
149
- legal_chain = LLMChain(llm=ChatOpenAI(temperature=0, openai_api_key=settings.OPENAI_API_KEY, model="gpt-4-turbo"), prompt=LEGAL_PROMPT)
150
- polish_chain = LLMChain(llm=ChatOpenAI(temperature=0.3, openai_api_key=settings.OPENAI_API_KEY, model="gpt-4-turbo"), prompt=POLISH_PROMPT)
151
-
152
- # ─── HELPERS ───────────────────────────────────────────────────────────────────
153
- def retrieve_context(query: str):
154
- docs = vectordb.similarity_search_with_score(query, k=settings.TOP_K)
155
- snippets = [f"[Source {i+1} | Relevance: {score:.2f}] {doc.page_content.strip()}" for i,(doc,score) in enumerate(docs)]
156
- sources = [f"Source {i+1}" for i in range(len(docs))]
157
- return "\n\n".join(snippets), sources
158
-
159
- # ─── MODELS ─────────────────────────────────────────────────────────────────────
160
- class QueryRequest(BaseModel):
161
- query: str
162
-
163
- class QueryResponse(BaseModel):
164
- answer: str
165
- session_id: str
166
- sources: list
167
-
168
- class SessionStatusResponse(BaseModel):
169
- status: str # "active", "expired", or "new"
170
- ttl: int # seconds until expiration (-2 = expired, -1 = no expiration)
171
- session_id: str | None
172
- created_at: datetime | None
173
- last_activity: datetime | None
174
- history_count: int | None
175
-
176
- # ─── ROUTES ─────────────────────────────────────────────────────────────────────
177
- @app.get("/", response_class=HTMLResponse)
178
- async def root():
179
- return FileResponse("frontend/index.html")
180
-
181
- @app.post("/query", response_model=QueryResponse)
182
- @limiter.limit(settings.RATE_LIMIT)
183
- async def handle_query(
184
- request: Request,
185
- req: QueryRequest,
186
- session: SessionData = Depends(get_session),
187
- response: Response = None
188
- ):
189
- if not await moderate_content(req.query):
190
- raise HTTPException(400, "Content policy violation")
191
-
192
- context, sources = retrieve_context(req.query)
193
- history = session.history[-3:] if session.history else []
194
-
195
- raw = legal_chain.run({"context": context, "question": req.query, "history": history})
196
- polished = polish_chain.run({"raw_answer": raw, "question": req.query})
197
- if not await moderate_content(polished):
198
- polished = "Restricted content."
199
-
200
- # Update session
201
- session.history.append({"q": req.query, "a": polished, "timestamp": datetime.utcnow().isoformat()})
202
- if len(session.history) > 5:
203
- session.history.pop(0)
204
-
205
- # Save with TTL refresh
206
- await redis.setex(
207
- session.session_id,
208
- settings.SESSION_TIMEOUT_MIN * 60,
209
- session.json()
210
- )
211
-
212
- return QueryResponse(answer=polished, session_id=session.session_id, sources=sources)
213
-
214
- @app.get("/session/status", response_model=SessionStatusResponse)
215
- async def get_session_status(session_id: str = Cookie(default=None)):
216
- if not session_id:
217
- return SessionStatusResponse(
218
- status="new",
219
- ttl=-2,
220
- session_id=None,
221
- created_at=None,
222
- last_activity=None,
223
- history_count=None
224
- )
225
-
226
- ttl = await redis.ttl(session_id)
227
- if ttl < 0: # Key doesn't exist or has no TTL
228
- return SessionStatusResponse(
229
- status="expired",
230
- ttl=-2,
231
- session_id=session_id,
232
- created_at=None,
233
- last_activity=None,
234
- history_count=None
235
- )
236
-
237
- raw = await redis.get(session_id)
238
- if not raw:
239
- return SessionStatusResponse(
240
- status="expired",
241
- ttl=-2,
242
- session_id=session_id,
243
- created_at=None,
244
- last_activity=None,
245
- history_count=None
246
- )
247
-
248
- data = SessionData.parse_raw(raw)
249
- return SessionStatusResponse(
250
- status="active",
251
- ttl=ttl,
252
- session_id=session_id,
253
- created_at=data.created_at,
254
- last_activity=data.last_activity,
255
- history_count=len(data.history)
256
- )
257
-
258
- # ─── SERVER LAUNCH ──────────────────────────────────────────────────────────────
259
- if __name__ == "__main__":
260
- import uvicorn
261
- uvicorn.run("app:app", host="0.0.0.0", port=8000, workers=4, log_level="info")
 
import asyncio
import logging
import os
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends, Response, Cookie
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from openai import OpenAI
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from upstash_redis.asyncio import Redis
27
# ─── SETTINGS ────────────────────────────────────────────────────────────────────
class Settings(BaseSettings):
    """Application configuration read from the environment / ``.env`` file."""

    # Required secrets — instantiation fails fast if any is missing.
    OPENAI_API_KEY: str
    UPSTASH_REDIS_REST_URL: str
    UPSTASH_REDIS_REST_TOKEN: str

    # Tunables with defaults.
    VECTOR_DB_PATH: str = "./chroma_db"  # Chroma persistence directory
    TOP_K: int = 5                       # retrieved documents per query
    SESSION_TIMEOUT_MIN: int = 30        # Redis session TTL, in minutes
    RATE_LIMIT: str = "60/minute"        # slowapi rate-limit expression

    class Config:
        env_file = ".env"
        extra = "ignore"  # tolerate unrelated variables in the environment
40
+
41
# Load .env into the process environment *before* instantiating Settings, so
# that values defined only in .env are visible both to Settings validation and
# to any other code that reads os.environ. (The original called load_dotenv()
# after Settings(), which defeats its purpose.)
load_dotenv()
settings = Settings()
43
+
44
+ # ─── LOGGING ─────────────────────────────────────────────────────────────────────
45
+ logging.basicConfig(
46
+ level=logging.INFO,
47
+ format='%(asctime)s %(levelname)s %(name)s %(message)s'
48
+ )
49
+ logger = logging.getLogger("legal-bot")
50
+
51
# ─── LIFESPAN MANAGEMENT ─────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the Upstash Redis client at startup and close it at shutdown.

    Publishes the client as the module-level ``redis`` global, which the
    session helpers and routes use directly.
    """
    global redis
    redis = Redis(
        url=settings.UPSTASH_REDIS_REST_URL,
        token=settings.UPSTASH_REDIS_REST_TOKEN
    )
    logger.info("Upstash Redis connection established")
    try:
        yield
    finally:
        # Close even if the app body raised, so the client is never leaked.
        await redis.close()
        logger.info("Upstash Redis connection closed")
63
+
64
# ─── FASTAPI APP ────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Irish Legal AI Bot",
    description="RAG‑driven Irish legal assistant",
    lifespan=lifespan,
)

# CORS: only the local frontend origin may call the API, with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8000"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Rate limiting keyed on client IP; SlowAPIMiddleware enforces the per-route
# limits declared with @limiter.limit(...).
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)
84
+
85
# ─── SECURITY & MODERATION ───────────────────────────────────────────────────────
openai_client = OpenAI(api_key=settings.OPENAI_API_KEY)

async def moderate_content(text: str) -> bool:
    """Return True if *text* passes the OpenAI moderation check.

    Fails closed: any API error is logged and treated as "not allowed".
    The OpenAI SDK call is synchronous, so it is run on a worker thread
    via asyncio.to_thread to avoid blocking the event loop.
    """
    try:
        resp = await asyncio.to_thread(openai_client.moderations.create, input=text)
        return not resp.results[0].flagged
    except Exception as e:
        logger.error(f"Moderation error: {e}")
        return False
95
+
96
# ─── SESSION MANAGEMENT ──────────────────────────────────────────────────────────
class SessionData(BaseModel):
    """Per-visitor conversation state, persisted in Redis as JSON."""

    session_id: str          # UUID4 string; also the Redis key and cookie value
    created_at: datetime     # when the session was first created
    last_activity: datetime  # refreshed on every access
    history: list            # rolling list of {"q", "a", "timestamp"} dicts
102
+
103
async def get_session(session_id: str = Cookie(default=None), response: Response = None) -> SessionData:
    """FastAPI dependency: load the caller's session or create a new one.

    An existing session gets its last_activity refreshed and its Redis TTL
    reset (sliding expiry). Otherwise a fresh session is created, stored
    with the configured TTL, and its id returned as an HttpOnly cookie.
    """
    if session_id:
        raw = await redis.get(session_id)
        if raw:
            data = SessionData.parse_raw(raw)
            # Sliding expiry: touching the session refreshes both the
            # activity timestamp and the Redis TTL.
            data.last_activity = datetime.utcnow()
            await redis.setex(
                session_id,
                settings.SESSION_TIMEOUT_MIN * 60,
                data.json()
            )
            return data

    # Cookie missing, or the session has expired: create a fresh one.
    new_id = str(uuid.uuid4())
    now = datetime.utcnow()
    data = SessionData(
        session_id=new_id,
        created_at=now,
        last_activity=now,
        history=[]
    )
    await redis.setex(
        new_id,
        settings.SESSION_TIMEOUT_MIN * 60,
        data.json()
    )
    # `response` is injected by FastAPI when used as a dependency; guard the
    # direct-call case where it keeps its declared default of None (the
    # original crashed with AttributeError here).
    if response is not None:
        response.set_cookie(key="session_id", value=new_id, httponly=True, secure=True)
    return data
132
+
133
# ─── VECTOR & LLM SETUP ─────────────────────────────────────────────────────────
# Embeddings + persistent Chroma store used for retrieval.
embeddings = OpenAIEmbeddings(openai_api_key=settings.OPENAI_API_KEY)
vectordb = Chroma(embedding_function=embeddings, persist_directory=settings.VECTOR_DB_PATH)

# First pass: grounded answer built strictly from retrieved context.
LEGAL_PROMPT = PromptTemplate(
    input_variables=["context", "question", "history"],
    template=(
        "As an Irish legal expert, provide a precise, concise answer using ONLY the context below."
        "\n1. Direct answer (1-2 sentences)\n2. Key legal basis (cite sources)\n3. Practical implications"
        "\n\nContext:\n{context}\n\nHistory:\n{history}\n\nQuestion: {question}\n\nAnswer:"
    ),
)

# Second pass: enrich/polish the raw answer, kept short.
POLISH_PROMPT = PromptTemplate(
    input_variables=["raw_answer", "question"],
    template=(
        "Enhance this Irish legal answer with current figures/fines (2024), recent amendments, and practical next steps."
        " Keep response under 150 words.\n\nOriginal:\n{raw_answer}\n\nQuestion: {question}\n\nEnhanced Answer:"
    ),
)

# Deterministic chain for the grounded answer; slightly creative for polish.
legal_chain = LLMChain(
    llm=ChatOpenAI(temperature=0, openai_api_key=settings.OPENAI_API_KEY, model="gpt-4-turbo"),
    prompt=LEGAL_PROMPT,
)
polish_chain = LLMChain(
    llm=ChatOpenAI(temperature=0.3, openai_api_key=settings.OPENAI_API_KEY, model="gpt-4-turbo"),
    prompt=POLISH_PROMPT,
)
151
+
152
# ─── HELPERS ───────────────────────────────────────────────────────────────────
def retrieve_context(query: str):
    """Similarity-search the vector store; return (joined snippets, source labels)."""
    scored_docs = vectordb.similarity_search_with_score(query, k=settings.TOP_K)
    snippets = []
    sources = []
    for idx, (doc, score) in enumerate(scored_docs, start=1):
        snippets.append(f"[Source {idx} | Relevance: {score:.2f}] {doc.page_content.strip()}")
        sources.append(f"Source {idx}")
    return "\n\n".join(snippets), sources
158
+
159
# ─── MODELS ─────────────────────────────────────────────────────────────────────
class QueryRequest(BaseModel):
    """Request body of POST /query."""

    query: str  # the user's legal question, plain text
162
+
163
class QueryResponse(BaseModel):
    """Response body of POST /query."""

    answer: str      # final (polished, moderated) answer text
    session_id: str  # session the exchange was recorded under
    sources: list    # labels of the retrieved context snippets
167
+
168
class SessionStatusResponse(BaseModel):
    """Response body of GET /session/status."""

    status: str  # "active", "expired", or "new"
    ttl: int     # seconds until expiration (-2 = expired/absent, -1 = no expiration)
    session_id: str | None
    created_at: datetime | None
    last_activity: datetime | None
    history_count: int | None
175
+
176
# ─── ROUTES ─────────────────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the single-page frontend."""
    return FileResponse("frontend/index.html")
180
+
181
+ @app.post("/query", response_model=QueryResponse)
182
+ @limiter.limit(settings.RATE_LIMIT)
183
+ async def handle_query(
184
+ request: Request,
185
+ req: QueryRequest,
186
+ session: SessionData = Depends(get_session),
187
+ response: Response = None
188
+ ):
189
+ if not await moderate_content(req.query):
190
+ raise HTTPException(400, "Content policy violation")
191
+
192
+ context, sources = retrieve_context(req.query)
193
+ history = session.history[-3:] if session.history else []
194
+
195
+ raw = legal_chain.run({"context": context, "question": req.query, "history": history})
196
+ polished = polish_chain.run({"raw_answer": raw, "question": req.query})
197
+ if not await moderate_content(polished):
198
+ polished = "Restricted content."
199
+
200
+ # Update session
201
+ session.history.append({"q": req.query, "a": polished, "timestamp": datetime.utcnow().isoformat()})
202
+ if len(session.history) > 5:
203
+ session.history.pop(0)
204
+
205
+ # Save with TTL refresh
206
+ await redis.setex(
207
+ session.session_id,
208
+ settings.SESSION_TIMEOUT_MIN * 60,
209
+ session.json()
210
+ )
211
+
212
+ return QueryResponse(answer=polished, session_id=session.session_id, sources=sources)
213
+
214
+ @app.get("/session/status", response_model=SessionStatusResponse)
215
+ async def get_session_status(session_id: str = Cookie(default=None)):
216
+ if not session_id:
217
+ return SessionStatusResponse(
218
+ status="new",
219
+ ttl=-2,
220
+ session_id=None,
221
+ created_at=None,
222
+ last_activity=None,
223
+ history_count=None
224
+ )
225
+
226
+ ttl = await redis.ttl(session_id)
227
+ if ttl < 0: # Key doesn't exist or has no TTL
228
+ return SessionStatusResponse(
229
+ status="expired",
230
+ ttl=-2,
231
+ session_id=session_id,
232
+ created_at=None,
233
+ last_activity=None,
234
+ history_count=None
235
+ )
236
+
237
+ raw = await redis.get(session_id)
238
+ if not raw:
239
+ return SessionStatusResponse(
240
+ status="expired",
241
+ ttl=-2,
242
+ session_id=session_id,
243
+ created_at=None,
244
+ last_activity=None,
245
+ history_count=None
246
+ )
247
+
248
+ data = SessionData.parse_raw(raw)
249
+ return SessionStatusResponse(
250
+ status="active",
251
+ ttl=ttl,
252
+ session_id=session_id,
253
+ created_at=data.created_at,
254
+ last_activity=data.last_activity,
255
+ history_count=len(data.history)
256
+ )
257
+
258
# ─── SERVER LAUNCH ──────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Dev entry point; in production launch via an external uvicorn command.
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")