Spaces:
Sleeping
Sleeping
| # app/api.py | |
| from __future__ import annotations | |
| import os | |
| import re | |
| from collections import deque | |
| from datetime import datetime, timezone | |
| from time import perf_counter | |
| from typing import List, Optional, Dict, Any | |
| import faiss | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| from pydantic import BaseModel, Field | |
| from .rag_system import SimpleRAG, UPLOAD_DIR, INDEX_DIR | |
__version__ = "1.3.1"

# Application object; health() and stats_endpoint() report app.version.
app = FastAPI(title="RAG API", version=__version__)

# CORS for the Streamlit UI: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; narrow allow_origins if the API is exposed beyond the UI.
_CORS_OPTIONS: Dict[str, Any] = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# One retrieval engine per process, used by the route handlers below.
rag = SimpleRAG()
| # -------------------- Schemas -------------------- | |
class UploadResponse(BaseModel):
    # Response body returned by upload_pdf().
    filename: str  # filename of the uploaded PDF as stored under UPLOAD_DIR
    chunks_added: int  # chunk count reported by rag.add_pdf() for this file
class AskRequest(BaseModel):
    # Request body for ask_question(): the question plus the retrieval depth.
    # Whitespace-only questions pass this validation but are rejected by
    # ask_question() after strip().
    question: str = Field(..., min_length=1)
    top_k: int = Field(5, ge=1, le=20)  # passed to rag.search() as k; default 5, allowed 1..20
class AskResponse(BaseModel):
    # Response body for ask_question().
    answer: str  # synthesized answer, extractive fallback, or a "not found" message
    contexts: List[str]  # chunk texts the answer was built from; empty when nothing was found
class HistoryItem(BaseModel):
    # One recorded question from StatsStore.history.
    question: str
    timestamp: str  # UTC time as ISO-8601 text with second precision
class HistoryResponse(BaseModel):
    # Response body for get_history().
    total_chunks: int  # len(rag.chunks) at request time
    # Recent questions, newest first.
    # NOTE(review): mutable [] default — presumably safe because pydantic copies
    # field defaults per instance; confirm for the pydantic version in use.
    history: List[HistoryItem] = []
| # -------------------- Stats (in-memory) -------------------- | |
class StatsStore:
    """In-memory usage counters for the API.

    All values live in this process only and start from zero on restart.
    """

    def __init__(self):
        self.documents_indexed = 0  # running sum of the counts passed to add_docs()
        self.questions_answered = 0
        self.latencies_ms = deque(maxlen=500)  # only the 500 most recent latencies are kept
        # UTC date that bucket 0 of last7_questions belongs to. It is set before
        # the deque so the property getter can always roll buckets safely.
        self._last7_day = datetime.now(timezone.utc).date()
        # Daily question counts: index 0 is today (UTC), index 6 is six days ago.
        self.last7_questions = deque([0] * 7, maxlen=7)
        self.history = deque(maxlen=50)  # newest first, at most 50 entries

    @property
    def last7_questions(self) -> deque:
        """Per-day question counts, rolled forward to today before being returned."""
        self._roll_days()
        return self._last7

    @last7_questions.setter
    def last7_questions(self, value: deque) -> None:
        # Assigning a fresh deque (e.g. when the index is reset) keeps the
        # current day alignment.
        self._last7 = value

    def _roll_days(self) -> None:
        """Shift the daily buckets so that index 0 corresponds to today (UTC)."""
        today = datetime.now(timezone.utc).date()
        elapsed = (today - self._last7_day).days
        if elapsed <= 0:
            return
        # appendleft on a full deque(maxlen=7) drops the oldest day from the
        # right. After 7+ idle days every bucket is zero, so further shifts
        # would change nothing.
        for _ in range(min(elapsed, 7)):
            self._last7.appendleft(0)
        self._last7_day = today

    def add_docs(self, n: int):
        """Add *n* to documents_indexed; zero or negative values are ignored."""
        if n > 0:
            self.documents_indexed += int(n)

    def add_question(self, latency_ms: Optional[int] = None, q: Optional[str] = None):
        """Record one answered question.

        Args:
            latency_ms: request latency in milliseconds; not recorded when None.
            q: question text; when non-empty it is prepended to history together
                with a UTC timestamp.
        """
        self.questions_answered += 1
        if latency_ms is not None:
            self.latencies_ms.append(int(latency_ms))
        buckets = self.last7_questions  # property access rolls the buckets to today
        if buckets:
            buckets[0] += 1
        if q:
            self.history.appendleft(
                {"question": q, "timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds")}
            )

    def avg_ms(self) -> int:
        """Return the mean recorded latency truncated to int, or 0 when none are recorded."""
        return int(sum(self.latencies_ms) / len(self.latencies_ms)) if self.latencies_ms else 0


stats = StatsStore()
| # -------------------- Helpers -------------------- | |
| _STOPWORDS = { | |
| "the","a","an","of","for","and","or","in","on","to","from","with","by","is","are", | |
| "was","were","be","been","being","at","as","that","this","these","those","it","its", | |
| "into","than","then","so","such","about","over","per","via","vs","within" | |
| } | |
| def _tokenize(s: str) -> List[str]: | |
| return [w for w in re.findall(r"[a-zA-Z0-9]+", s.lower()) if w and w not in _STOPWORDS and len(w) > 2] | |
| def _is_generic_answer(text: str) -> bool: | |
| if not text: | |
| return True | |
| low = text.strip().lower() | |
| if len(low) < 15: | |
| return True | |
| # tipik generik pattern-lər | |
| if "based on document context" in low or "appears to be" in low: | |
| return True | |
| return False | |
| def _extractive_fallback(question: str, contexts: List[str], max_chars: int = 600) -> str: | |
| """ Sualın açar sözlərinə əsasən kontekstdən cümlələr seç. """ | |
| if not contexts: | |
| return "I couldn't find relevant information in the indexed documents for this question." | |
| qtok = set(_tokenize(question)) | |
| if not qtok: | |
| return (contexts[0] or "")[:max_chars] | |
| # cümlələrə böl və skorla | |
| sentences: List[str] = [] | |
| for c in contexts: | |
| for s in re.split(r"(?<=[\.!\?])\s+|\n+", (c or "").strip()): | |
| s = s.strip() | |
| if s: | |
| sentences.append(s) | |
| scored: List[tuple[int, str]] = [] | |
| for s in sentences: | |
| st = set(_tokenize(s)) | |
| scored.append((len(qtok & st), s)) | |
| scored.sort(key=lambda x: (x[0], len(x[1])), reverse=True) | |
| picked: List[str] = [] | |
| for sc, s in scored: | |
| if sc <= 0 and picked: | |
| break | |
| if len((" ".join(picked) + " " + s).strip()) > max_chars: | |
| break | |
| picked.append(s) | |
| if not picked: | |
| return (contexts[0] or "")[:max_chars] | |
| bullets = "\n".join(f"- {p}" for p in picked) | |
| return f"Answer (based on document context):\n{bullets}" | |
| # -------------------- Routes -------------------- | |
def root():
    # Redirect the bare root URL to the /docs page.
    docs_path = "/docs"
    return RedirectResponse(url=docs_path)
def health():
    # Diagnostic payload: app version, a fixed summarizer description, and the
    # current FAISS index size and vector dimension. The dimension falls back
    # to rag.embed_dim when the index has no `d` attribute.
    index = rag.index
    vector_count = int(getattr(index, "ntotal", 0))
    dimension = int(getattr(index, "d", rag.embed_dim))
    payload = {
        "status": "ok",
        "version": app.version,
        "summarizer": "extractive_en + translate + keyword_fallback",
        "faiss_ntotal": vector_count,
        "model_dim": dimension,
    }
    return payload
# Translation pipelines built on first use, keyed by model name. The model is
# then loaded once per process instead of on every request.
_TRANSLATORS: Dict[str, Any] = {}


def debug_translate():
    # Smoke-test the Azerbaijani-to-English translation model on a fixed sentence.
    # Returns {"ok": True, "example_out": ...}, or a 500 JSON body carrying the
    # error text if loading or translating fails.
    model_name = "Helsinki-NLP/opus-mt-az-en"
    try:
        tr = _TRANSLATORS.get(model_name)
        if tr is None:
            # Heavy import, only needed by this endpoint.
            from transformers import pipeline

            # device=-1 selects CPU in transformers' device convention.
            tr = pipeline("translation", model=model_name, cache_dir=str(rag.cache_dir), device=-1)
            _TRANSLATORS[model_name] = tr
        out = tr("Sənəd təmiri və quraşdırılması ilə bağlı işlər görülüb.", max_length=80)[0]["translation_text"]
        return {"ok": True, "example_out": out}
    except Exception as e:
        return JSONResponse(status_code=500, content={"ok": False, "error": str(e)})
async def upload_pdf(file: UploadFile = File(...)):
    # Save an uploaded PDF under UPLOAD_DIR and index it with rag.add_pdf().
    # Raises HTTPException(400) when the filename is missing or not a .pdf, or
    # when no text could be extracted. Returns the stored filename and the
    # number of chunks added.
    raw_name = file.filename or ""
    # The client controls the filename: keep only its final path component so
    # values like "../../x.pdf" or "C:\\dir\\x.pdf" cannot escape UPLOAD_DIR.
    safe_name = os.path.basename(raw_name.replace("\\", "/"))
    if "\x00" in safe_name or not safe_name.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed.")
    dest = UPLOAD_DIR / safe_name
    with open(dest, "wb") as f:
        while True:
            chunk = await file.read(1024 * 1024)  # 1 MiB at a time to bound memory use
            if not chunk:
                break
            f.write(chunk)
    # NOTE(review): rag.add_pdf() is a synchronous call inside this async
    # handler, so it blocks the event loop while the PDF is being indexed.
    added = rag.add_pdf(dest)
    if added == 0:
        raise HTTPException(status_code=400, detail="No extractable text found (likely a scanned image PDF).")
    stats.add_docs(added)
    return UploadResponse(filename=safe_name, chunks_added=added)
def ask_question(payload: AskRequest):
    # Answer a question from the indexed documents.
    # Flow:
    #   1. search with the question embedding;
    #   2. collect context texts;
    #   3. synthesize via rag;
    #   4. use the extractive fallback when the synthesized text looks generic.
    # Latency and the question are recorded in `stats` on both the answered
    # path and the "not found" path.
    question = (payload.question or "").strip()
    if not question:
        raise HTTPException(status_code=400, detail="Missing 'question'.")
    top_k = max(1, int(payload.top_k))
    started = perf_counter()

    # 1) Retrieval always uses the question embedding.
    try:
        hits = rag.search(question, k=top_k)  # presumably a list of (text, score) pairs
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {e}")

    # 2) Keep non-empty hit texts. If there are none, fall back to
    #    rag.last_added (presumably the chunks from the latest upload).
    contexts = [text for text, _score in (hits or []) if text]
    if not contexts:
        recent = getattr(rag, "last_added", None)
        contexts = recent[:top_k] if recent else []

    if not contexts:
        stats.add_question(int((perf_counter() - started) * 1000), q=question)
        return AskResponse(
            answer="I couldn't find relevant information in the indexed documents for this question.",
            contexts=[],
        )

    # 3) Synthesis is best-effort: any failure degrades to the fallback below.
    try:
        answer = (rag.synthesize_answer(question, contexts) or "").strip()
    except Exception:
        answer = ""

    # 4) Replace generic-looking output with sentences picked from the contexts.
    if _is_generic_answer(answer):
        answer = _extractive_fallback(question, contexts, max_chars=600)

    stats.add_question(int((perf_counter() - started) * 1000), q=question)
    return AskResponse(answer=answer, contexts=contexts)
def get_history():
    # Current chunk count plus the recorded questions, newest first as stored by
    # stats. History is copied to a list before the response models are built.
    total = len(rag.chunks)
    items = [HistoryItem(**entry) for entry in list(stats.history)]
    return HistoryResponse(total_chunks=total, history=items)
def stats_endpoint():
    # Usage counters from `stats` plus index diagnostics from `rag`.
    # NOTE: "documents_indexed" is the running sum of the chunk counts that
    # upload_pdf passes to stats.add_docs(), not the number of files.
    index = rag.index
    # rag.last_added may be unset or None; treat both as "no chunks".
    last_added = getattr(rag, "last_added", None) or []
    return {
        "documents_indexed": stats.documents_indexed,
        "questions_answered": stats.questions_answered,
        # avg_ms is a method and must be called. Returning the bound method
        # would put a non-JSON-serializable object in the response.
        "avg_ms": stats.avg_ms(),
        "last7_questions": list(stats.last7_questions),
        "total_chunks": len(rag.chunks),
        "faiss_ntotal": int(getattr(index, "ntotal", 0)),
        "model_dim": int(getattr(index, "d", rag.embed_dim)),
        "last_added_chunks": len(last_added),
        "version": app.version,
    }
def reset_index():
    # Wipe the in-memory state and return {"ok": True}. Steps:
    #   - replace the FAISS index with an empty IndexFlatIP of rag.embed_dim;
    #   - clear rag.chunks and rag.last_added;
    #   - delete the persisted index files under INDEX_DIR (missing files are ignored);
    #   - zero every usage statistic.
    # Any failure is reported as HTTP 500 carrying the exception text.
    # NOTE(review): uploaded PDFs in UPLOAD_DIR are left on disk.
    try:
        rag.index = faiss.IndexFlatIP(rag.embed_dim)
        rag.chunks = []
        rag.last_added = []

        for filename in ("faiss.index", "meta.npy"):
            try:
                os.remove(INDEX_DIR / filename)
            except FileNotFoundError:
                pass  # already gone: nothing to delete

        stats.documents_indexed = 0
        stats.questions_answered = 0
        stats.latencies_ms.clear()
        stats.last7_questions = deque([0] * 7, maxlen=7)
        stats.history.clear()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"ok": True}