"""FastAPI service for semantic search over Polish court judgments.

Pipeline: GPT-based query rewriting -> multilingual-E5 embedding ->
Qdrant vector search -> cross-encoder re-ranking. Full judgment texts
are fetched on demand from the public SAOS API.
"""
import os
import re

import aiohttp
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from openai import OpenAI
from pydantic import BaseModel
from qdrant_client import QdrantClient
from sentence_transformers import CrossEncoder, SentenceTransformer

load_dotenv()

app = FastAPI(title="Polish Law Search API")

FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:3000")

# CORS for the Next.js frontend.
# NOTE: allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers (the wildcard origin is invalid for credentialed
# requests), so the configured frontend URL is used instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[FRONTEND_URL],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Configuration ---
COLLECTION_NAME = "polish_law_e5"
MODEL_NAME = "intfloat/multilingual-e5-large"
RERANKER_MODEL = "sdadas/polish-reranker-large-ranknet"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Globals populated once at startup by load_models() ---
embedder = None        # SentenceTransformer bi-encoder for query embeddings
reranker = None        # CrossEncoder for second-stage re-ranking
qdrant = None          # Qdrant vector-store client
openai_client = None   # OpenAI client used for query rewriting


# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
# favour of lifespan handlers; kept for compatibility with the FastAPI
# version this project pins — confirm before migrating.
@app.on_event("startup")
async def load_models():
    """Load ML models and API clients once when the app starts."""
    global embedder, reranker, qdrant, openai_client
    print(f"🔧 Ładowanie modeli na {DEVICE}...")
    embedder = SentenceTransformer(MODEL_NAME, device=DEVICE)
    embedder.max_seq_length = 512
    reranker = CrossEncoder(RERANKER_MODEL, max_length=512, device=DEVICE)
    qdrant = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
        timeout=120,
    )
    openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    print("✅ Modele załadowane!")


# --- Request/Response models ---
class SearchRequest(BaseModel):
    query: str
    num_results: int = 5
    use_reranking: bool = True


class SearchResult(BaseModel):
    origin_id: int
    signature: str
    judgment_date: str
    court_type: str
    court_name: str
    judgment_type: str
    keywords: list[str]
    judges: list[str]
    matched_chunk: str
    score: float


class SearchResponse(BaseModel):
    original_query: str
    optimized_query: str
    results: list[SearchResult]


# --- Helpers ---
QUERY_REWRITE_PROMPT = """Jesteś ekspertem od polskiego prawa. 
Przekształć zapytanie na optymalną frazę do wyszukiwania semantycznego. Zasady: usuń zbędne słowa, dodaj synonimy prawnicze, dodaj artykuły kodeksu jeśli znasz. Odpowiedz TYLKO zoptymalizowaną frazą (5-20 słów). Zapytanie: {query} Zoptymalizowana fraza:"""


def rewrite_query(query: str) -> str:
    """Rewrite a user query into a search-optimised phrase via GPT.

    Best-effort: any failure (network, quota, malformed response) falls
    back to the original query so search keeps working.
    """
    try:
        response = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "user", "content": QUERY_REWRITE_PROMPT.format(query=query)}
            ],
            max_tokens=100,
            temperature=0.3,
        )
        return response.choices[0].message.content.strip()
    except Exception:  # was a bare except: never swallow SystemExit/KeyboardInterrupt
        return query


def rerank_results(query: str, results, top_k: int):
    """Re-score (query, chunk) pairs with the cross-encoder.

    Returns up to top_k (result, score) tuples sorted by descending
    relevance; an empty candidate list yields an empty list.
    """
    if not results:
        return []
    # .get() guards against a chunk whose payload lacks page_content.
    pairs = [(query, r.payload.get("page_content", "")) for r in results]
    scores = reranker.predict(pairs)
    reranked = sorted(zip(results, scores), key=lambda x: x[1], reverse=True)
    return [(r, float(s)) for r, s in reranked[:top_k]]


def _payload_to_result(p: dict, score: float) -> SearchResult:
    """Map a raw Qdrant payload dict onto a SearchResult (missing keys -> defaults)."""
    return SearchResult(
        origin_id=p.get("origin_id", 0),
        signature=p.get("signature", ""),
        judgment_date=p.get("judgment_date", ""),
        court_type=p.get("court_type", ""),
        court_name=p.get("court_name", ""),
        judgment_type=p.get("judgment_type", ""),
        keywords=p.get("keywords", []),
        judges=p.get("judges", []),
        matched_chunk=p.get("page_content", ""),
        score=score,
    )


# --- Endpoints ---
@app.get("/health")
def health_check():
    """Liveness probe; also reports which device the models run on."""
    return {"status": "ok", "device": DEVICE}


@app.post("/search", response_model=SearchResponse)
def search(request: SearchRequest):
    """Semantic search over judgment chunks.

    Steps: rewrite the query, embed it, retrieve candidates from
    Qdrant, optionally re-rank with the cross-encoder, then shape the
    payloads into SearchResult objects.
    """
    # 1. Rewrite the query for better recall.
    optimized = rewrite_query(request.query)

    # 2. Embed. E5-family models expect the "query: " prefix on queries.
    query_vector = embedder.encode(
        f"query: {optimized}", normalize_embeddings=True
    ).tolist()

    # 3. Vector search. Over-fetch 5x when re-ranking so the
    #    cross-encoder has a candidate pool to choose from.
    fetch_limit = (
        request.num_results * 5 if request.use_reranking else request.num_results
    )
    search_results = qdrant.query_points(
        collection_name=COLLECTION_NAME,
        query=query_vector,
        limit=fetch_limit,
        with_payload=True,
    ).points

    # 4. Re-ranking (skipped when disabled or when nothing was found).
    if request.use_reranking and search_results:
        scored_results = rerank_results(optimized, search_results, request.num_results)
    else:
        scored_results = [
            (r, float(r.score)) for r in search_results[: request.num_results]
        ]

    # 5. Format results.
    results = [_payload_to_result(r.payload, score) for r, score in scored_results]

    return SearchResponse(
        original_query=request.query,
        optimized_query=optimized,
        results=results,
    )


@app.get("/judgment/{judgment_id}")
async def get_full_judgment(judgment_id: int):
    """Fetch the full judgment text from the public SAOS API.

    HTML tags are stripped from the text and whitespace collapsed.
    Raises HTTP 404 when SAOS returns a non-200 response.
    """
    url = f"https://www.saos.org.pl/api/judgments/{judgment_id}"
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status != 200:
                raise HTTPException(status_code=404, detail="Judgment not found")
            data = await response.json()

    judgment_data = data.get("data", {})

    # Strip HTML tags, then collapse runs of whitespace.
    raw_text = judgment_data.get("textContent", "")
    clean_text = re.sub(r"<.*?>", " ", raw_text)
    clean_text = " ".join(clean_text.split())

    court_cases = judgment_data.get("courtCases", [])
    signature = court_cases[0].get("caseNumber", "") if court_cases else ""

    return {
        "id": judgment_id,
        "signature": signature,
        "text": clean_text,
        "judgment_date": judgment_data.get("judgmentDate", ""),
        "court_type": judgment_data.get("courtType", ""),
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)