minh-4T commited on
Commit
4fb223c
·
1 Parent(s): c0748b8

upload document from admin

Browse files
api/admin_documents_router.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ from typing import Any, Dict, List
4
+
5
+ from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Query, UploadFile
6
+ from fastapi.concurrency import run_in_threadpool
7
+ from sqlalchemy.orm import Session
8
+
9
+ from core.config import MAX_UPLOAD_SIZE_MB, UPLOAD_DIR
10
+ from core.document_db import Document, get_document_db
11
+ from core.document_ingest_service import run_document_ingest_task
12
+
13
+ router = APIRouter(prefix="/admin/documents", tags=["admin-documents"])
14
+
15
+ _ALLOWED_EXTENSIONS = {".pdf", ".docx", ".txt"}
16
+
17
+
18
+ class FileTooLargeError(Exception):
19
+ pass
20
+
21
+
22
+ def _save_upload_file_stream(file_obj: Any, destination: str, max_size_bytes: int) -> int:
23
+ total_size = 0
24
+ chunk_size = 1024 * 1024
25
+
26
+ with open(destination, "wb") as output:
27
+ while True:
28
+ chunk = file_obj.read(chunk_size)
29
+ if not chunk:
30
+ break
31
+
32
+ total_size += len(chunk)
33
+ if total_size > max_size_bytes:
34
+ raise FileTooLargeError("Uploaded file exceeds configured maximum size.")
35
+
36
+ output.write(chunk)
37
+
38
+ return total_size
39
+
40
+
41
+ @router.post("/upload")
42
+ async def upload_document(
43
+ background_tasks: BackgroundTasks,
44
+ file: UploadFile = File(...),
45
+ db: Session = Depends(get_document_db),
46
+ ) -> Dict[str, Any]:
47
+ if not file.filename:
48
+ raise HTTPException(status_code=400, detail="File name is required.")
49
+
50
+ extension = os.path.splitext(file.filename)[1].lower()
51
+ if extension not in _ALLOWED_EXTENSIONS:
52
+ raise HTTPException(status_code=400, detail="Unsupported file type. Allowed: .pdf, .docx, .txt")
53
+
54
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
55
+ stored_name = f"{uuid.uuid4()}{extension}"
56
+ stored_path = os.path.abspath(os.path.join(UPLOAD_DIR, stored_name))
57
+ max_size_bytes = MAX_UPLOAD_SIZE_MB * 1024 * 1024
58
+
59
+ try:
60
+ file.file.seek(0)
61
+ size = await run_in_threadpool(
62
+ _save_upload_file_stream,
63
+ file.file,
64
+ stored_path,
65
+ max_size_bytes,
66
+ )
67
+ except FileTooLargeError:
68
+ if os.path.exists(stored_path):
69
+ os.remove(stored_path)
70
+ raise HTTPException(
71
+ status_code=413,
72
+ detail=f"File is too large. Max allowed size is {MAX_UPLOAD_SIZE_MB} MB.",
73
+ )
74
+ except Exception as error:
75
+ if os.path.exists(stored_path):
76
+ os.remove(stored_path)
77
+ raise HTTPException(status_code=500, detail=f"Failed to save file: {error}")
78
+ finally:
79
+ await file.close()
80
+
81
+ document = Document(
82
+ original_name=file.filename,
83
+ stored_name=stored_name,
84
+ path=stored_path,
85
+ mime_type=file.content_type or "application/octet-stream",
86
+ size=size,
87
+ status="pending",
88
+ total_chunks=0,
89
+ )
90
+ db.add(document)
91
+ db.commit()
92
+ db.refresh(document)
93
+
94
+ background_tasks.add_task(run_document_ingest_task, document.id)
95
+
96
+ return {
97
+ "status": "success",
98
+ "document_id": document.id,
99
+ "original_name": document.original_name,
100
+ "stored_name": document.stored_name,
101
+ "path": document.path,
102
+ }
103
+
104
+
105
+ @router.get("/status/{document_id}")
106
+ def get_document_status(document_id: str, db: Session = Depends(get_document_db)) -> Dict[str, Any]:
107
+ document = db.query(Document).filter(Document.id == document_id).first()
108
+ if document is None:
109
+ raise HTTPException(status_code=404, detail="Document not found.")
110
+
111
+ return {
112
+ "status": "success",
113
+ "document_id": document.id,
114
+ "processing_status": document.status,
115
+ "total_chunks": document.total_chunks,
116
+ "error_message": document.error_message,
117
+ "created_at": document.created_at,
118
+ }
119
+
120
+
121
+ @router.get("")
122
+ def list_documents(
123
+ limit: int = Query(default=20, ge=1, le=100),
124
+ offset: int = Query(default=0, ge=0),
125
+ db: Session = Depends(get_document_db),
126
+ ) -> Dict[str, List[Dict[str, Any]]]:
127
+ records = (
128
+ db.query(Document)
129
+ .order_by(Document.created_at.desc())
130
+ .offset(offset)
131
+ .limit(limit)
132
+ .all()
133
+ )
134
+
135
+ return {
136
+ "status": "success",
137
+ "items": [
138
+ {
139
+ "id": doc.id,
140
+ "original_name": doc.original_name,
141
+ "stored_name": doc.stored_name,
142
+ "status": doc.status,
143
+ "total_chunks": doc.total_chunks,
144
+ "created_at": doc.created_at,
145
+ }
146
+ for doc in records
147
+ ],
148
+ }
core/config.py CHANGED
@@ -26,6 +26,10 @@ FINAL_TOP_K = int(os.getenv('FINAL_TOP_K', '3'))
26
 
27
  DATA_DIR = os.getenv('DATA_DIR', 'data')
28
  VECTOR_DIR = os.getenv('VECTOR_DIR', 'vectorstore')
 
 
 
 
29
 
30
  # External service configs
31
  QDRANT_URL = os.getenv('QDRANT_URL')
 
26
 
27
  DATA_DIR = os.getenv('DATA_DIR', 'data')
28
  VECTOR_DIR = os.getenv('VECTOR_DIR', 'vectorstore')
29
+ UPLOAD_DIR = os.getenv('UPLOAD_DIR', 'uploads')
30
+ MAX_UPLOAD_SIZE_MB = int(os.getenv('MAX_UPLOAD_SIZE_MB', '20'))
31
+ QDRANT_COLLECTION = os.getenv('QDRANT_COLLECTION', 'rag_docs')
32
+ DOCUMENTS_DATABASE_URL = os.getenv('DOCUMENTS_DATABASE_URL', 'sqlite:///./rag_metadata.db')
33
 
34
  # External service configs
35
  QDRANT_URL = os.getenv('QDRANT_URL')
core/document_db.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from datetime import datetime, timezone
3
+
4
+ from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text, create_engine
5
+ from sqlalchemy.orm import declarative_base, relationship, sessionmaker
6
+
7
+ from .config import DOCUMENTS_DATABASE_URL
8
+
9
+ Base = declarative_base()
10
+
11
+ _connect_args = {"check_same_thread": False} if DOCUMENTS_DATABASE_URL.startswith("sqlite") else {}
12
+ engine = create_engine(DOCUMENTS_DATABASE_URL, connect_args=_connect_args)
13
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
14
+
15
+
16
+ def utcnow() -> datetime:
17
+ return datetime.now(timezone.utc)
18
+
19
+
20
+ class Document(Base):
21
+ __tablename__ = "documents"
22
+
23
+ id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
24
+ original_name = Column(String(512), nullable=False)
25
+ stored_name = Column(String(512), nullable=False)
26
+ path = Column(String(1024), nullable=False)
27
+ mime_type = Column(String(255), nullable=False)
28
+ size = Column(Integer, nullable=False)
29
+ status = Column(String(32), nullable=False, default="pending")
30
+ total_chunks = Column(Integer, nullable=False, default=0)
31
+ error_message = Column(Text, nullable=True)
32
+ created_at = Column(DateTime(timezone=True), nullable=False, default=utcnow)
33
+
34
+ chunks = relationship("DocumentChunk", back_populates="document", cascade="all, delete-orphan")
35
+
36
+
37
+ class DocumentChunk(Base):
38
+ __tablename__ = "document_chunks"
39
+
40
+ id = Column(Integer, primary_key=True, autoincrement=True)
41
+ document_id = Column(String(36), ForeignKey("documents.id", ondelete="CASCADE"), nullable=False)
42
+ chunk_index = Column(Integer, nullable=False)
43
+ content_preview = Column(String(200), nullable=False)
44
+ qdrant_point_id = Column(String(64), nullable=True)
45
+ created_at = Column(DateTime(timezone=True), nullable=False, default=utcnow)
46
+
47
+ document = relationship("Document", back_populates="chunks")
48
+
49
+
50
+ def init_document_db() -> None:
51
+ Base.metadata.create_all(bind=engine)
52
+
53
+
54
+ def get_document_db():
55
+ db = SessionLocal()
56
+ try:
57
+ yield db
58
+ finally:
59
+ db.close()
core/document_ingest_service.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+ import uuid
5
+ from datetime import datetime, timezone
6
+ from typing import List
7
+
8
+ from docx import Document as DocxDocument
9
+ from fastapi.concurrency import run_in_threadpool
10
+ from pypdf import PdfReader
11
+ from qdrant_client import QdrantClient
12
+ from qdrant_client.models import Distance, PointStruct, VectorParams
13
+
14
+ from .config import CHUNK_OVERLAP, CHUNK_SIZE, QDRANT_API_KEY, QDRANT_COLLECTION, QDRANT_URL
15
+ from .document_db import Document, DocumentChunk, SessionLocal
16
+ from .models import embeddings
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _ALLOWED_EXTENSIONS = {".pdf", ".docx", ".txt"}
21
+ _WHITESPACE_RE = re.compile(r"\s+")
22
+ _TOKEN_RE = re.compile(r"\S+")
23
+
24
+
25
+ def normalize_text(text: str) -> str:
26
+ if not text:
27
+ return ""
28
+
29
+ cleaned = text.replace("\x00", " ")
30
+ cleaned = cleaned.replace("\ufeff", " ")
31
+ cleaned = cleaned.replace("\u200b", " ").replace("\u200c", " ").replace("\u200d", " ")
32
+ cleaned = _WHITESPACE_RE.sub(" ", cleaned)
33
+ return cleaned.strip()
34
+
35
+
36
+ def read_document_content(path: str, extension: str) -> str:
37
+ extension = extension.lower()
38
+ if extension not in _ALLOWED_EXTENSIONS:
39
+ raise ValueError(f"Unsupported file extension: {extension}")
40
+
41
+ if extension == ".pdf":
42
+ reader = PdfReader(path)
43
+ page_texts = [(page.extract_text() or "") for page in reader.pages]
44
+ return "\n".join(page_texts)
45
+
46
+ if extension == ".docx":
47
+ doc = DocxDocument(path)
48
+ paragraphs = [p.text for p in doc.paragraphs if p.text]
49
+
50
+ for table in doc.tables:
51
+ for row in table.rows:
52
+ row_cells = [cell.text.strip() for cell in row.cells]
53
+ if any(row_cells):
54
+ paragraphs.append(" | ".join(row_cells))
55
+
56
+ return "\n".join(paragraphs)
57
+
58
+ with open(path, "r", encoding="utf-8", errors="ignore") as file:
59
+ return file.read()
60
+
61
+
62
+ def chunk_text_by_tokens(text: str, chunk_size: int, overlap: int) -> List[str]:
63
+ if chunk_size <= 0:
64
+ raise ValueError("CHUNK_SIZE must be > 0")
65
+ if overlap < 0:
66
+ raise ValueError("CHUNK_OVERLAP must be >= 0")
67
+ if overlap >= chunk_size:
68
+ raise ValueError("CHUNK_OVERLAP must be smaller than CHUNK_SIZE")
69
+
70
+ tokens = _TOKEN_RE.findall(text)
71
+ if not tokens:
72
+ return []
73
+
74
+ step = chunk_size - overlap
75
+ chunks: List[str] = []
76
+
77
+ for start in range(0, len(tokens), step):
78
+ end = min(start + chunk_size, len(tokens))
79
+ piece = " ".join(tokens[start:end]).strip()
80
+ if piece:
81
+ chunks.append(piece)
82
+ if end >= len(tokens):
83
+ break
84
+
85
+ return chunks
86
+
87
+
88
+ def _ensure_qdrant_collection(client: QdrantClient, vector_size: int) -> None:
89
+ if not client.collection_exists(collection_name=QDRANT_COLLECTION):
90
+ client.create_collection(
91
+ collection_name=QDRANT_COLLECTION,
92
+ vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
93
+ )
94
+
95
+
96
+ def process_document_ingest(document_id: str) -> None:
97
+ db = SessionLocal()
98
+ document = db.query(Document).filter(Document.id == document_id).first()
99
+
100
+ if document is None:
101
+ db.close()
102
+ logger.error("Document not found for ingest: %s", document_id)
103
+ return
104
+
105
+ try:
106
+ document.status = "processing"
107
+ document.error_message = None
108
+ db.commit()
109
+
110
+ _, extension = os.path.splitext(document.stored_name)
111
+ raw_text = read_document_content(document.path, extension)
112
+ normalized = normalize_text(raw_text)
113
+ chunks = chunk_text_by_tokens(normalized, CHUNK_SIZE, CHUNK_OVERLAP)
114
+
115
+ if not chunks:
116
+ raise ValueError("Document has no readable content after normalization.")
117
+
118
+ if not QDRANT_URL:
119
+ raise ValueError("QDRANT_URL is required for ingest.")
120
+
121
+ client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY or None)
122
+ vectors = embeddings.embed_documents(chunks)
123
+
124
+ if not vectors or not vectors[0]:
125
+ raise ValueError("Failed to create embeddings for chunks.")
126
+
127
+ _ensure_qdrant_collection(client, len(vectors[0]))
128
+
129
+ created_at = datetime.now(timezone.utc).isoformat()
130
+ points: List[PointStruct] = []
131
+ db_chunk_rows: List[DocumentChunk] = []
132
+
133
+ for index, (chunk_text, vector) in enumerate(zip(chunks, vectors)):
134
+ point_id = str(uuid.uuid4())
135
+ payload = {
136
+ "document_id": document.id,
137
+ "filename": document.original_name,
138
+ "stored_name": document.stored_name,
139
+ "path": document.path,
140
+ "chunk_index": index,
141
+ "created_at": created_at,
142
+ "content": chunk_text,
143
+ }
144
+
145
+ points.append(PointStruct(id=point_id, vector=vector, payload=payload))
146
+ db_chunk_rows.append(
147
+ DocumentChunk(
148
+ document_id=document.id,
149
+ chunk_index=index,
150
+ content_preview=chunk_text[:200],
151
+ qdrant_point_id=point_id,
152
+ )
153
+ )
154
+
155
+ client.upsert(collection_name=QDRANT_COLLECTION, points=points, wait=True)
156
+
157
+ db.query(DocumentChunk).filter(DocumentChunk.document_id == document.id).delete()
158
+ db.bulk_save_objects(db_chunk_rows)
159
+
160
+ document.total_chunks = len(chunks)
161
+ document.status = "done"
162
+ db.commit()
163
+
164
+ logger.info("Document ingest success. document_id=%s total_chunks=%s", document.id, len(chunks))
165
+ except Exception as error:
166
+ db.rollback()
167
+
168
+ failed_doc = db.query(Document).filter(Document.id == document_id).first()
169
+ if failed_doc is not None:
170
+ failed_doc.status = "failed"
171
+ failed_doc.error_message = str(error)
172
+ db.commit()
173
+
174
+ logger.exception("Document ingest failed. document_id=%s", document_id)
175
+ finally:
176
+ db.close()
177
+
178
+
179
+ async def run_document_ingest_task(document_id: str) -> None:
180
+ # Heavy ingest work runs in threadpool to keep event loop responsive.
181
+ await run_in_threadpool(process_document_ingest, document_id)
core/vectorstore.py CHANGED
@@ -8,7 +8,7 @@ from docx import Document
8
  from .models import embeddings
9
  from .text_utils import clean_text
10
  from .chunking import smart_chunking
11
- from .config import DATA_DIR, VECTOR_DIR, QDRANT_API_KEY, QDRANT_URL
12
  from langchain_core.documents import Document as LangChainDocument
13
  import zipfile
14
  import xml.etree.ElementTree as ET
@@ -25,7 +25,7 @@ logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
 
27
  CHUNKS_PICKLE = os.path.join(VECTOR_DIR, "chunks.pkl")
28
- COLLECTION_NAME = "quy_che_db"
29
  # [YEAR-AWARE CHANGE] Ho tro quet de quy va gan metadata nam hoc.
30
  SUPPORTED_FORMATS = ('.pdf', '.doc', '.docx')
31
  ACADEMIC_YEAR_PATTERN = re.compile(r"(20\d{2})\s*[-_]\s*(20\d{2})")
 
8
  from .models import embeddings
9
  from .text_utils import clean_text
10
  from .chunking import smart_chunking
11
+ from .config import DATA_DIR, VECTOR_DIR, QDRANT_API_KEY, QDRANT_URL, QDRANT_COLLECTION
12
  from langchain_core.documents import Document as LangChainDocument
13
  import zipfile
14
  import xml.etree.ElementTree as ET
 
25
  logger = logging.getLogger(__name__)
26
 
27
  CHUNKS_PICKLE = os.path.join(VECTOR_DIR, "chunks.pkl")
28
+ COLLECTION_NAME = QDRANT_COLLECTION
29
  # [YEAR-AWARE CHANGE] Ho tro quet de quy va gan metadata nam hoc.
30
  SUPPORTED_FORMATS = ('.pdf', '.doc', '.docx')
31
  ACADEMIC_YEAR_PATTERN = re.compile(r"(20\d{2})\s*[-_]\s*(20\d{2})")
main.py CHANGED
@@ -11,10 +11,12 @@ import asyncpg
11
  from starlette.concurrency import iterate_in_threadpool
12
  from qdrant_client import QdrantClient
13
  #Import các model và các hàm cần thiết từ core
14
- from core.config import QDRANT_URL, QDRANT_API_KEY, DATABASE_URL
 
15
  from core.vectorstore import build_vectorstore_improved, load_vectorstore_improved
16
  from core.retriever import HybridRetriever
17
  from core.qa_pipeline import ask_ai_improved, ask_ai_stream_delta
 
18
 
19
  # Hàm log lỗi an toàn
20
  logging.basicConfig(level=logging.INFO)
@@ -109,6 +111,8 @@ async def lifespan(app: FastAPI):
109
  logger.info("Đang khởi tạo API SERVER ...")
110
  pool = None
111
  try:
 
 
112
  pool = await asyncpg.create_pool(
113
  dsn=DATABASE_URL,
114
  min_size=POOL_MIN_SIZE,
@@ -118,7 +122,7 @@ async def lifespan(app: FastAPI):
118
  await init_db_asyncpg(pool)
119
 
120
  client = QdrantClient(url = QDRANT_URL, api_key=QDRANT_API_KEY)
121
- collection_name= "quy_che_db"
122
  if not client.collection_exists(collection_name):
123
  logger.warning(f"Chưa có collection {collection_name} trên Qdrant Cloud. Đang xây dựng vectorstore mới...")
124
  db, all_chunks= build_vectorstore_improved()
@@ -151,6 +155,7 @@ def get_runtime_components(request: Request):
151
 
152
  #Cấu hình FastAPI với middleware CORS và lifespan để quản lý trạng thái hệ thống
153
  app = FastAPI(lifespan=lifespan, title= "RAG API SERVER")
 
154
 
155
  #Cho phép truy cập từ mọi nguồn
156
  allow_origins = [origin.strip() for origin in os.getenv("ALLOW_ORIGINS", "*").split(",") if origin.strip()]
 
11
  from starlette.concurrency import iterate_in_threadpool
12
  from qdrant_client import QdrantClient
13
  #Import các model và các hàm cần thiết từ core
14
+ from core.config import QDRANT_URL, QDRANT_API_KEY, DATABASE_URL, QDRANT_COLLECTION
15
+ from core.document_db import init_document_db
16
  from core.vectorstore import build_vectorstore_improved, load_vectorstore_improved
17
  from core.retriever import HybridRetriever
18
  from core.qa_pipeline import ask_ai_improved, ask_ai_stream_delta
19
+ from api.admin_documents_router import router as admin_documents_router
20
 
21
  # Hàm log lỗi an toàn
22
  logging.basicConfig(level=logging.INFO)
 
111
  logger.info("Đang khởi tạo API SERVER ...")
112
  pool = None
113
  try:
114
+ init_document_db()
115
+
116
  pool = await asyncpg.create_pool(
117
  dsn=DATABASE_URL,
118
  min_size=POOL_MIN_SIZE,
 
122
  await init_db_asyncpg(pool)
123
 
124
  client = QdrantClient(url = QDRANT_URL, api_key=QDRANT_API_KEY)
125
+ collection_name = QDRANT_COLLECTION
126
  if not client.collection_exists(collection_name):
127
  logger.warning(f"Chưa có collection {collection_name} trên Qdrant Cloud. Đang xây dựng vectorstore mới...")
128
  db, all_chunks= build_vectorstore_improved()
 
155
 
156
  #Cấu hình FastAPI với middleware CORS và lifespan để quản lý trạng thái hệ thống
157
  app = FastAPI(lifespan=lifespan, title= "RAG API SERVER")
158
+ app.include_router(admin_documents_router)
159
 
160
  #Cho phép truy cập từ mọi nguồn
161
  allow_origins = [origin.strip() for origin in os.getenv("ALLOW_ORIGINS", "*").split(",") if origin.strip()]
requirements.txt CHANGED
@@ -12,6 +12,7 @@ google-generativeai>=0.7.0
12
 
13
  # Database & Vector Store
14
  asyncpg>=0.29.0
 
15
  qdrant-client>=1.9.0
16
 
17
  #Embedding Models & Transformers
 
12
 
13
  # Database & Vector Store
14
  asyncpg>=0.29.0
15
+ sqlalchemy>=2.0.0
16
  qdrant-client>=1.9.0
17
 
18
  #Embedding Models & Transformers