File size: 11,934 Bytes
1a4ceb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
999e6d4
 
a23971e
1a4ceb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40102e3
 
1a4ceb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e989c18
1a4ceb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5090f7f
1a4ceb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e989c18
1a4ceb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0145b4c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
import asyncio
import os
import uuid
import logging
from datetime import datetime, timedelta
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Request, Depends, Response, Cookie
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles  # used to mount the /static directory

from pydantic import BaseModel
from pydantic_settings import BaseSettings
from dotenv import load_dotenv

from upstash_redis.asyncio import Redis
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from slowapi.middleware import SlowAPIMiddleware

from openai import OpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain_classic.chains import LLMChain
from langchain_core.prompts import PromptTemplate

# ─── SETTINGS ────────────────────────────────────────────────────────────────────
class Settings(BaseSettings):
    """Application configuration, declared as a pydantic-settings model."""

    # Credentials and endpoints; no default value is declared for these three.
    OPENAI_API_KEY: str
    UPSTASH_REDIS_REST_URL: str
    UPSTASH_REDIS_REST_TOKEN: str
    # Directory passed to Chroma as its persist_directory.
    VECTOR_DB_PATH: str = "./chroma_db"
    # Number of documents requested from the vector store per query.
    TOP_K: int = 5
    # Fixed session lifetime in minutes, counted from session creation.
    SESSION_TIMEOUT_MIN: int = 30
    # Limit string handed to the slowapi decorator on the /query route.
    RATE_LIMIT: str = "60/minute"

    class Config:
        # Values are also read from a local .env file; unknown keys are ignored.
        env_file = ".env"
        extra = "ignore"

# Module-level settings instance shared by everything below.
settings = Settings()
# NOTE(review): load_dotenv() runs after Settings() has already been built, so
# it does not influence the values above; presumably it is here to export .env
# entries into os.environ for other libraries — confirm.
load_dotenv()

# ─── LOGGING ─────────────────────────────────────────────────────────────────────
# Configure root logging at INFO with a timestamp/level/name/message layout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(name)s %(message)s'
)
# Logger used by the rest of this module.
logger = logging.getLogger("legal-bot")

# ─── LIFESPAN MANAGEMENT ─────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the Upstash Redis client for the app's lifetime and close it afterwards.

    The client is stored in the module-level ``redis`` global, which the
    session dependency and the route handlers read.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global redis
    redis = Redis(
        url=settings.UPSTASH_REDIS_REST_URL,
        token=settings.UPSTASH_REDIS_REST_TOKEN,
    )
    logger.info("Upstash Redis connection established")
    try:
        yield
    finally:
        # Close the client even when the application exits through an error,
        # so the underlying HTTP session is not leaked.
        await redis.close()
        logger.info("Upstash Redis connection closed")

# ─── FASTAPI APP ────────────────────────────────────────────────────────────────
# Application instance; `lifespan` creates and closes the Redis client.
app = FastAPI(
    title="Irish Legal AI Bot",
    description="RAG‑driven Irish legal assistant",
    lifespan=lifespan
)

# Serve files from the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# CORS - Updated for Hugging Face Spaces
# NOTE(review): all origins are allowed while credentials are enabled, which
# lets any site send cookie-bearing requests — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Rate limiting
# Limits are keyed by the value get_remote_address returns for each request.
# NOTE(review): behind a reverse proxy this may be the proxy's address rather
# than the end user's — verify against the deployment.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)

# ─── SECURITY & MODERATION ───────────────────────────────────────────────────────
# Synchronous OpenAI client shared by the moderation helper.
openai_client = OpenAI(api_key=settings.OPENAI_API_KEY)

async def moderate_content(text: str) -> bool:
    """Return True when the OpenAI moderation endpoint does not flag ``text``.

    The shared OpenAI client is synchronous, so the request is run in a worker
    thread to keep the event loop free for other requests.

    Args:
        text: The text to check.

    Returns:
        True if the text passed moderation. False if it was flagged or if the
        moderation call failed (the check fails closed).
    """
    try:
        resp = await asyncio.to_thread(openai_client.moderations.create, input=text)
        return not resp.results[0].flagged
    except Exception as e:
        # Deliberate best effort: any failure is logged and treated as "not allowed".
        logger.error("Moderation error: %s", e)
        return False

# ─── SESSION MANAGEMENT ──────────────────────────────────────────────────────────
class SessionData(BaseModel):
    """Session record stored in Redis as JSON under the session id."""

    session_id: str
    created_at: datetime
    expires_at: datetime  # Fixed expiration time, set once when the session is created
    last_activity: datetime  # Refreshed each time the session is loaded
    history: list  # Question/answer entries appended by the /query handler

async def get_session(session_id: str = Cookie(default=None), response: Response = None) -> SessionData:
    """Return the caller's live session, or create a new one and set its cookie.

    A session has a fixed lifetime: ``expires_at`` is set at creation and never
    extended. Loading a live session only refreshes ``last_activity`` and
    re-saves the record with the remaining time as its Redis TTL.

    Args:
        session_id: Value of the ``session_id`` cookie, if the client sent one.
        response: Response object on which the cookie for a new session is
            set. May be None, in which case no cookie is written.

    Returns:
        The existing session when it is readable and has at least one second
        left; otherwise a freshly created session.
    """
    if session_id:
        raw = await redis.get(session_id)
        if raw:
            try:
                data = SessionData.parse_raw(raw)
            except ValueError:
                # The cookie is client-controlled, so the key may hold something
                # that is not a session record. Leave it alone and start fresh.
                logger.warning("Ignoring unreadable session payload")
            else:
                now = datetime.utcnow()
                remaining_seconds = int((data.expires_at - now).total_seconds())
                if remaining_seconds >= 1:
                    # Update last activity without changing expiration.
                    data.last_activity = now
                    await redis.setex(session_id, remaining_seconds, data.json())
                    return data
                # Expired, or under a second left: Redis rejects a SETEX expiry
                # of 0, so treat it as expired and remove the record.
                await redis.delete(session_id)

    # Create new session with fixed expiration
    new_id = str(uuid.uuid4())
    created_at = datetime.utcnow()
    expires_at = created_at + timedelta(minutes=settings.SESSION_TIMEOUT_MIN)
    data = SessionData(
        session_id=new_id,
        created_at=created_at,
        expires_at=expires_at,
        last_activity=created_at,
        history=[],
    )
    await redis.setex(
        new_id,
        settings.SESSION_TIMEOUT_MIN * 60,
        data.json(),
    )
    if response is not None:
        response.set_cookie(
            key="session_id",
            value=new_id,
            httponly=True,
            secure=True,
            samesite="None",
            path="/",
        )
    return data

# ─── VECTOR & LLM SETUP ─────────────────────────────────────────────────────────
# Embedding function and persistent Chroma store used for context retrieval.
embeddings = OpenAIEmbeddings(openai_api_key=settings.OPENAI_API_KEY)
vectordb = Chroma(embedding_function=embeddings, persist_directory=settings.VECTOR_DB_PATH)

# First-pass prompt: answer from the retrieved context and recent history.
LEGAL_PROMPT = PromptTemplate(
    input_variables=["context","question","history"],
    template=(
        "As an Irish legal expert, provide a precise, concise answer using ONLY the context below."
        "\n1. Direct answer (1-2 sentences)\n2. Key legal basis (cite sources)\n3. Practical implications"
        "\n\nContext:\n{context}\n\nHistory:\n{history}\n\nQuestion: {question}\n\nAnswer:")
)

# Second-pass prompt: rewrite the first answer.
# NOTE(review): this prompt asks for figures and amendments but receives no
# retrieved context, so those additions are not grounded in the vector store —
# confirm this is acceptable.
POLISH_PROMPT = PromptTemplate(
    input_variables=["raw_answer","question"],
    template=(
        "Enhance this Irish legal answer with current figures/fines (2024), recent amendments, and practical next steps."
        " Keep response under 150 words.\n\nOriginal:\n{raw_answer}\n\nQuestion: {question}\n\nEnhanced Answer:")
)

# Chain producing the context-based answer (temperature 0).
legal_chain = LLMChain(
    llm=ChatOpenAI(temperature=0, openai_api_key=settings.OPENAI_API_KEY, model="gpt-4-turbo"), 
    prompt=LEGAL_PROMPT
)

# Chain producing the polished answer (temperature 0.3).
polish_chain = LLMChain(
    llm=ChatOpenAI(temperature=0.3, openai_api_key=settings.OPENAI_API_KEY, model="gpt-4-turbo"), 
    prompt=POLISH_PROMPT
)

# ─── HELPERS ───────────────────────────────────────────────────────────────────
def retrieve_context(query: str):
    """Fetch the top matching documents for ``query`` and format them for the prompt.

    Returns:
        A 2-tuple: the labelled snippets joined by blank lines, and the list
        of matching labels ("Source 1", "Source 2", ...).
    """
    hits = vectordb.similarity_search_with_score(query, k=settings.TOP_K)

    formatted = []
    labels = []
    # NOTE(review): the score is printed as "Relevance"; the store may return a
    # distance where lower means more similar — confirm against the Chroma docs.
    for rank, (doc, score) in enumerate(hits, start=1):
        label = f"Source {rank}"
        formatted.append(f"[{label} | Relevance: {score:.2f}] {doc.page_content.strip()}")
        labels.append(label)

    return "\n\n".join(formatted), labels

# ─── MODELS ─────────────────────────────────────────────────────────────────────
class QueryRequest(BaseModel):
    """Request body for POST /query."""

    query: str  # The user's question

class QueryResponse(BaseModel):
    """Response body for POST /query."""

    answer: str  # Polished answer, or a fixed notice if it failed moderation
    session_id: str
    sources: list  # Labels of the retrieved snippets ("Source 1", ...)

class SessionStatusResponse(BaseModel):
    """Response body for GET /session/status."""

    status: str  # "active", "expired", or "new"
    ttl: int     # seconds until expiration; -2 when there is no live session
    session_id: str | None
    created_at: datetime | None
    expires_at: datetime | None  # Fixed expiration time of the session
    last_activity: datetime | None
    history_count: int | None  # Number of stored history entries

class SessionHistoryResponse(BaseModel):
    """Response body for GET /session/history."""

    history: list  # Stored question/answer entries for the session
    session_id: str

# ─── ROUTES ─────────────────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def root():
    """Return the ``index.html`` file as the response for the site root."""
    landing_page = FileResponse("index.html")
    return landing_page

@app.post("/query", response_model=QueryResponse)
@limiter.limit(settings.RATE_LIMIT)
async def handle_query(
    request: Request,
    req: QueryRequest,
    session: SessionData = Depends(get_session),
    response: Response = None
):
    """Answer a legal question using retrieved context and the session history.

    Steps: moderate the question, retrieve context, generate a draft answer,
    polish it, moderate the result, then append the exchange to the session.

    Args:
        request: Incoming request; required by the rate-limit decorator.
        req: Request body carrying the user's question.
        session: Current session, resolved by ``get_session``.
        response: Outgoing response object (not used directly here).

    Returns:
        QueryResponse with the answer, the session id and the source labels.

    Raises:
        HTTPException: 400 when the question fails moderation.
    """
    if not await moderate_content(req.query):
        raise HTTPException(400, "Content policy violation")

    # Retrieval and both chains are blocking calls; run them in worker threads
    # so this coroutine does not stall the event loop for other requests.
    context, sources = await asyncio.to_thread(retrieve_context, req.query)
    history = session.history[-3:]

    raw = await asyncio.to_thread(
        legal_chain.run,
        {"context": context, "question": req.query, "history": history},
    )
    polished = await asyncio.to_thread(
        polish_chain.run,
        {"raw_answer": raw, "question": req.query},
    )
    if not await moderate_content(polished):
        polished = "Restricted content."

    # Update session without changing expiration; keep only the latest 5 entries.
    session.history.append({"q": req.query, "a": polished, "timestamp": datetime.utcnow().isoformat()})
    del session.history[:-5]

    # Save with original expiration. The session may have run out while the
    # model calls were in flight; Redis rejects a non-positive SETEX expiry,
    # so in that case skip the write instead of failing the request.
    remaining_seconds = int((session.expires_at - datetime.utcnow()).total_seconds())
    if remaining_seconds >= 1:
        await redis.setex(
            session.session_id,
            remaining_seconds,
            session.json(),
        )
    else:
        logger.info("Session expired during request; history not persisted")

    return QueryResponse(answer=polished, session_id=session.session_id, sources=sources)

@app.get("/session/status", response_model=SessionStatusResponse)
async def get_session_status(session_id: str = Cookie(default=None)):
    """Report the state of the session named by the ``session_id`` cookie.

    Returns status "new" when no cookie was sent, "expired" when no record is
    stored or its expiry has passed, and "active" otherwise. ``ttl`` is the
    number of seconds left for an active session and -2 in every other case.
    This endpoint only reads; it never creates or modifies a session.
    """
    # Detail fields reported when no stored record is available.
    no_details = dict(
        created_at=None,
        expires_at=None,
        last_activity=None,
        history_count=None,
    )

    if not session_id:
        return SessionStatusResponse(status="new", ttl=-2, session_id=None, **no_details)

    payload = await redis.get(session_id)
    if not payload:
        return SessionStatusResponse(status="expired", ttl=-2, session_id=session_id, **no_details)

    record = SessionData.parse_raw(payload)
    current = datetime.utcnow()

    if current > record.expires_at:
        state, seconds_left = "expired", -2
    else:
        state = "active"
        seconds_left = int((record.expires_at - current).total_seconds())

    return SessionStatusResponse(
        status=state,
        ttl=seconds_left,
        session_id=session_id,
        created_at=record.created_at,
        expires_at=record.expires_at,
        last_activity=record.last_activity,
        history_count=len(record.history),
    )

@app.get("/session/history", response_model=SessionHistoryResponse)
async def get_session_history(session: SessionData = Depends(get_session)):
    """Return the stored history and id of the session resolved by ``get_session``."""
    return dict(history=session.history, session_id=session.session_id)

# ─── SERVER LAUNCH ──────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Listen on the port given by the PORT environment variable, else 7860.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=listen_port,
        workers=4,
        log_level="info",
    )