Spaces:
Sleeping
Sleeping
| # ============================================ | |
| # file: app/main.py | |
| # ============================================ | |
| from __future__ import annotations | |
| from typing import List | |
| import anyio | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from .deps import get_chain, reload_chain | |
| from .history import get_history, add_to_history, clear_history | |
| from .schemas import AskRequest, AskResponse, SourceDoc | |
| from .utils import convert_to_eastern_arabic | |
# FastAPI application instance for this service.
app = FastAPI(title="Legal RAG API", version="1.0.0")

# Optional: allow browser frontends served from other origins to call the API.
# Credentials are kept off while origins/methods/headers are wildcards: the
# FastAPI CORS documentation states that none of them may be ["*"] when
# allow_credentials=True. If a frontend needs cookies or Authorization
# headers, list its exact origins below and then re-enable credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten later
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _startup():
    """Call ``get_chain()`` once and ignore what it returns.

    The original inline comment describes this as a one-time preload.

    NOTE(review): no decorator or registration call for this hook is visible
    in the source as reviewed -- confirm it is actually wired to ``app``.
    """
    _ = get_chain()
def health():
    """Return the fixed payload ``{"status": "ok"}``.

    NOTE(review): no route decorator is visible in the source as reviewed --
    confirm how this handler is registered with ``app``.
    """
    return dict(status="ok")
def reload():
    """Invoke ``reload_chain()`` and then report ``{"status": "reloaded"}``.

    NOTE(review): no route decorator is visible in the source as reviewed --
    confirm how this handler is registered with ``app``.
    """
    reload_chain()
    outcome = {"status": "reloaded"}
    return outcome
def clear_session(session_id: str = "default"):
    """Clear the stored conversation for ``session_id`` and echo the id back.

    Args:
        session_id: Session whose history is passed to ``clear_history``;
            defaults to ``"default"``.

    Returns:
        A dict with ``"status": "cleared"`` and the ``session_id`` that was
        cleared.

    NOTE(review): no route decorator is visible in the source as reviewed --
    confirm how this handler is registered with ``app``.
    """
    clear_history(session_id)
    return dict(status="cleared", session_id=session_id)
def get_session_history(session_id: str = "default"):
    """Return the conversation stored for ``session_id`` as plain dicts.

    Args:
        session_id: Session to look up via ``get_history``; defaults to
            ``"default"``.

    Returns:
        A dict holding the ``session_id`` and a ``messages`` list in which
        each entry exposes the ``role`` and ``content`` attributes of one
        history item, in the order ``get_history`` yields them.

    NOTE(review): no route decorator is visible in the source as reviewed --
    confirm how this handler is registered with ``app``.
    """
    messages = []
    for entry in get_history(session_id):
        messages.append({"role": entry.role, "content": entry.content})
    return {"session_id": session_id, "messages": messages}
def _metadata_text(metadata, key):
    """Return ``metadata[key]`` as a string, or ``None`` if absent, None or empty."""
    value = metadata.get(key)
    if value is None:
        # str(None) would yield the truthy string "None"; treat it as missing.
        return None
    return str(value) or None


def _dedupe_sources(docs) -> List[SourceDoc]:
    """Convert retrieved documents into ``SourceDoc`` items, one per article number.

    Documents are processed in order. The first document seen for a given
    non-empty ``article_number`` is kept and later ones with the same number
    are dropped. Documents without an article number are never treated as
    duplicates of each other.

    Args:
        docs: Iterable of objects exposing ``metadata`` (a mapping, or None)
            and ``page_content``. May be ``None`` or empty.

    Returns:
        The de-duplicated ``SourceDoc`` list; empty when ``docs`` is falsy.
        Metadata values that are missing, ``None`` or empty become ``None``
        on the resulting ``SourceDoc`` rather than the literal text "None".
    """
    if not docs:
        return []
    seen = set()
    out: List[SourceDoc] = []
    for doc in docs:
        # Tolerate a document whose metadata is None instead of a mapping.
        metadata = doc.metadata or {}
        article_num = (_metadata_text(metadata, "article_number") or "").strip()
        if article_num:
            if article_num in seen:
                continue
            seen.add(article_num)
        out.append(
            SourceDoc(
                article_id=_metadata_text(metadata, "article_id"),
                article_number=article_num or None,
                law_name=_metadata_text(metadata, "law_name"),
                legal_nature=_metadata_text(metadata, "legal_nature"),
                keywords=_metadata_text(metadata, "keywords"),
                part=_metadata_text(metadata, "part"),
                chapter=_metadata_text(metadata, "chapter"),
                page_content=str(doc.page_content or ""),
            )
        )
    return out
async def ask(payload: AskRequest):
    """Answer ``payload.query`` using the chain and the session's prior turns.

    Steps, in order:
      1. Load the history for ``payload.session_id`` and pass it to
         ``get_chain`` as a list of ``{"role", "content"}`` dicts.
      2. Run ``chain.invoke(payload.query)`` in a worker thread.
      3. Read ``"answer"`` (and ``"context"`` when sources are requested)
         from the result and de-duplicate the sources.
      4. Optionally convert the answer, each source's ``page_content`` and
         its ``article_number`` with ``convert_to_eastern_arabic``.
      5. Record the query and the answer (after any numeral conversion) in
         the session history.

    Args:
        payload: Request carrying ``query``, ``session_id``,
            ``include_sources`` and ``eastern_arabic_numerals``.

    Returns:
        An ``AskResponse`` with the answer, the sources, the session id and
        the unmodified chain result as ``raw``.

    Raises:
        HTTPException: Status 500 with the exception text as ``detail`` when
            invoking the chain raises.

    NOTE(review): no route decorator is visible in the source as reviewed --
    confirm how this handler is registered with ``app``.
    """
    # Prior turns for this session, flattened to plain role/content dicts.
    prior_turns = [
        {"role": turn.role, "content": turn.content}
        for turn in get_history(payload.session_id)
    ]

    # Build the chain with the conversation so far as context.
    rag_chain = get_chain(conversation_history=prior_turns)

    try:
        # invoke() is synchronous, so hand it to a worker thread.
        output = await anyio.to_thread.run_sync(rag_chain.invoke, payload.query)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    answer = output.get("answer", "")

    if payload.include_sources:
        retrieved = output.get("context", [])
    else:
        retrieved = []
    sources = _dedupe_sources(retrieved)

    if payload.eastern_arabic_numerals:
        answer = convert_to_eastern_arabic(answer)
        if payload.include_sources:
            for source in sources:
                source.page_content = convert_to_eastern_arabic(source.page_content)
                if source.article_number:
                    source.article_number = convert_to_eastern_arabic(
                        source.article_number
                    )

    # Persist this exchange so later requests in the session can see it.
    add_to_history(payload.session_id, payload.query, answer)

    return AskResponse(
        answer=answer,
        sources=sources,
        session_id=payload.session_id,
        raw=output,
    )