# vanfun_be / searchdata.py
# Uploaded by moonbaek ("Upload 118 files"), commit 67819f1 (verified), 21.3 kB.
"""
DATA & VECTOR DB ENGINEER - Người số 1
Nhiệm vụ: Xử lý đầu vào và trí nhớ cho hệ thống
1. Kết nối Google Drive API tải file đáp án (HỖ TRỢ SUB-FOLDER)
2. Xử lý thô văn bản
3. Chunking (cắt nhỏ)
4. Embedding và đẩy lên Qdrant
5. Viết hàm search_context()
"""
import os
import re
import logging
from typing import List, Dict, Optional
from dataclasses import dataclass
import unicodedata
from pathlib import Path
# === CORE PACKAGES ===
import numpy as np
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance, VectorParams, PointStruct
)
# === GOOGLE DRIVE - DÙNG SERVICE ACCOUNT ===
from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload
# === LOGGING ===
# Configure the root logger at import time (INFO level, timestamped lines)
# and expose one module-level logger used throughout this file.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================
# 1. CẤU HÌNH
# ============================================================
@dataclass
class Config:
    """Settings for the Data & Vector DB Engineer pipeline.

    Instantiating the config also makes sure ``RAW_DIR`` exists on disk.
    """

    # --- Google Drive (authenticated with a Service Account key file) ---
    GOOGLE_CREDENTIALS_FILE: str = "service-account-key.json"
    GOOGLE_FOLDER_ID: str = "1RLjyoxo88y0wpQNgTG4uhFPdWQyRBmIn"

    # --- Qdrant ---
    QDRANT_HOST: str = "localhost"
    QDRANT_PORT: int = 6333
    QDRANT_COLLECTION_NAME: str = "van_mau"

    # --- Embedding model identifier passed to SentenceTransformer ---
    EMBEDDING_MODEL_NAME: str = "dangvantuan/vietnamese-embedding"

    # --- Chunking (lengths measured in characters by ChunkingStrategy) ---
    CHUNK_SIZE: int = 300
    CHUNK_OVERLAP: int = 50

    # --- Local storage for downloaded files ---
    RAW_DIR: str = "data/raw"

    def __post_init__(self):
        raw_dir = Path(self.RAW_DIR)
        raw_dir.mkdir(exist_ok=True, parents=True)
# ============================================================
# 2. XỬ LÝ VĂN BẢN
# ============================================================
class TextProcessor:
    """Raw text cleanup: Unicode normalization plus whitespace collapsing."""

    @staticmethod
    def normalize_unicode(text: str) -> str:
        """Return *text* in Unicode NFC (composed) form."""
        return unicodedata.normalize('NFC', text)

    @staticmethod
    def remove_extra_spaces(text: str) -> str:
        """Collapse every whitespace run (newlines included) to one space and trim both ends."""
        collapsed = re.sub(r'\s+', ' ', text)
        return collapsed.strip()

    @staticmethod
    def clean_text(text: str) -> str:
        """NFC-normalize *text*, then collapse its whitespace."""
        normalized = TextProcessor.normalize_unicode(text)
        return TextProcessor.remove_extra_spaces(normalized)
# ============================================================
# 3. CHUNKING
# ============================================================
class ChunkingStrategy:
    """Split a model answer into size-bounded, overlapping chunks.

    All lengths (``chunk_size``, ``overlap``) are measured in characters.
    """

    # Preferred cut points, tried in this order, when a window must be shortened.
    _SEPARATORS = ('. ', '; ', ', ', ' ')

    def __init__(self, chunk_size: int = 300, overlap: int = 50):
        """
        Args:
            chunk_size: maximum chunk length in characters; must be > 0.
            overlap: characters shared by consecutive chunks cut from the same
                paragraph; must satisfy ``0 <= overlap < chunk_size``.

        Raises:
            ValueError: if the sizes are inconsistent. With
                ``overlap >= chunk_size`` a window could never advance by more
                than one character.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= overlap < chunk_size:
            raise ValueError(
                f"overlap must be in [0, chunk_size={chunk_size}), got {overlap}"
            )
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk_smart(self, text: str) -> List[str]:
        """Split *text* on blank lines; paragraphs longer than chunk_size are cut further.

        Returns:
            Stripped, non-empty chunks in document order.
        """
        chunks: List[str] = []
        for para in re.split(r'\n\s*\n', text):
            para = para.strip()
            if not para:
                continue
            if len(para) <= self.chunk_size:
                chunks.append(para)
            else:
                chunks.extend(self._chunk_by_size(para))
        return chunks

    def _chunk_by_size(self, text: str) -> List[str]:
        """Cut *text* into windows of at most chunk_size chars, overlapping by ~overlap chars.

        A window that does not reach the end of the text is shortened to end
        right after the last preferred separator inside it. That only happens
        if the shortened window is still longer than ``overlap``. Otherwise the
        next window would restart at (almost) the same position and emit tiny
        near-duplicate fragments.
        """
        chunks: List[str] = []
        start = 0
        text_len = len(text)
        while start < text_len:
            end = min(start + self.chunk_size, text_len)
            if end < text_len:
                for sep in self._SEPARATORS:
                    pos = text.rfind(sep, start, end)
                    # rfind returns the LAST match in the window, so if even that
                    # one is inside the overlap zone, this separator is unusable.
                    if pos != -1 and pos + len(sep) > start + self.overlap:
                        end = pos + len(sep)
                        break
            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)
            if end >= text_len:
                # The window reached the end of the text. Stepping back by
                # `overlap` now would only re-emit suffixes of this chunk.
                break
            # end > start + overlap is guaranteed above whenever a separator was
            # used; max() keeps the loop advancing in every case.
            start = max(start + 1, end - self.overlap)
        return chunks
# ============================================================
# 4. GOOGLE DRIVE API - HỖ TRỢ SUB-FOLDER
# ============================================================
class GoogleDriveManager:
    """Read-only Google Drive client (Service Account) that walks sub-folders recursively."""

    # MIME type Drive uses for folders.
    FOLDER_MIME_TYPE = 'application/vnd.google-apps.folder'
    # Prefix of Google Workspace-native types (Docs, Sheets, shortcuts, ...).
    # These are skipped because files().get_media only downloads stored binary
    # content; Workspace documents would need files().export instead.
    WORKSPACE_MIME_PREFIX = 'application/vnd.google-apps'

    def __init__(self, credentials_file: str = "service-account-key.json"):
        """Authenticate immediately using the Service Account key at *credentials_file*."""
        self.credentials_file = credentials_file
        self.service = None
        self._authenticate()

    def _authenticate(self):
        """Build the Drive v3 service with a read-only Service Account credential.

        Raises:
            FileNotFoundError: if the key file does not exist.
            Exception: any error from the Google auth/discovery libraries is
                logged and re-raised.
        """
        try:
            if not os.path.exists(self.credentials_file):
                raise FileNotFoundError(
                    f"Không tìm thấy file {self.credentials_file}! "
                    "Hãy tạo Service Account và tải key JSON về."
                )
            creds = service_account.Credentials.from_service_account_file(
                self.credentials_file,
                scopes=['https://www.googleapis.com/auth/drive.readonly']
            )
            self.service = build('drive', 'v3', credentials=creds)
            logger.info("✅ Google Drive authenticated with Service Account.")
        except Exception as e:
            logger.error(f"❌ Authentication failed: {e}")
            raise

    def list_all_files_recursive(self, folder_id: str, prefix: str = "") -> List[Dict]:
        """List every non-folder item under *folder_id*, descending into sub-folders.

        Every page of results is fetched (``nextPageToken``), so folders with
        more than ``pageSize`` items are listed completely.

        Args:
            folder_id: Drive ID of the folder to scan.
            prefix: relative path of this folder ("" at the root, else ending in "/").

        Returns:
            File metadata dicts as returned by Drive, each with an added
            ``full_path`` key (path relative to the root folder). Listing is
            best-effort: an API error is logged and the items gathered so far
            are returned.
        """
        all_files: List[Dict] = []
        try:
            query = f"'{folder_id}' in parents and trashed = false"
            page_token = None
            while True:
                results = self.service.files().list(
                    q=query,
                    fields="nextPageToken, files(id, name, mimeType, size, createdTime, modifiedTime)",
                    pageSize=100,
                    pageToken=page_token
                ).execute()
                for item in results.get('files', []):
                    if item.get('mimeType') == self.FOLDER_MIME_TYPE:
                        # Folder -> recurse into it.
                        logger.info(f"📁 Đang đào sâu vào: {prefix}{item['name']}/")
                        all_files.extend(self.list_all_files_recursive(
                            item['id'],
                            prefix=f"{prefix}{item['name']}/"
                        ))
                    else:
                        # File -> record it with its relative path.
                        item['full_path'] = f"{prefix}{item['name']}"
                        all_files.append(item)
                        logger.info(f" 📄 Found: {item['full_path']} (ID: {item['id']})")
                page_token = results.get('nextPageToken')
                if not page_token:
                    break
        except Exception as e:
            logger.error(f"❌ Error listing files in folder {folder_id}: {e}")
        return all_files

    @staticmethod
    def _safe_local_path(destination: str, relative_path: str) -> Optional[str]:
        """Return ``os.path.join(destination, relative_path)``, or None if it escapes *destination*.

        Drive file names are not trusted: a name containing ``..`` or starting
        with ``/`` must not be able to write outside the download directory.
        """
        file_path = os.path.join(destination, relative_path)
        base = os.path.abspath(destination)
        resolved = os.path.abspath(file_path)
        try:
            inside = os.path.commonpath([base, resolved]) == base
        except ValueError:
            # e.g. different drives on Windows.
            inside = False
        if not inside or resolved == base:
            return None
        return file_path

    def download_all_files_recursive(self, folder_id: str, destination: str) -> List[str]:
        """Download every downloadable file under *folder_id*, mirroring its folder tree.

        Args:
            folder_id: Drive ID of the root folder.
            destination: local directory to download into.

        Returns:
            Local paths (joined onto *destination*) of files that downloaded
            successfully. Workspace-native files, unsafe names and failed
            downloads are logged and skipped.

        NOTE(review): Drive allows duplicate names in one folder; such files map
        to the same local path, so a later download overwrites an earlier one.
        """
        logger.info("🔍 Đang quét toàn bộ folder và sub-folder...")
        all_files = self.list_all_files_recursive(folder_id)
        if not all_files:
            logger.warning("⚠️ Không tìm thấy file nào trong folder hoặc sub-folder.")
            return []
        logger.info(f"📊 Tổng số file tìm thấy: {len(all_files)}")

        downloaded_files = []
        for file_info in all_files:
            full_path = file_info.get('full_path', file_info['name'])
            # Skip before creating any directory so no empty folders are left behind.
            if file_info.get('mimeType', '').startswith(self.WORKSPACE_MIME_PREFIX):
                logger.warning(f"⚠️ Skipping Google Workspace file: {full_path}")
                continue
            file_path = self._safe_local_path(destination, full_path)
            if file_path is None:
                logger.warning(f"⚠️ Skipping file with unsafe path outside {destination}: {full_path}")
                continue
            try:
                parent_dir = os.path.dirname(file_path)
                if parent_dir:
                    os.makedirs(parent_dir, exist_ok=True)
                request = self.service.files().get_media(fileId=file_info['id'])
                with open(file_path, 'wb') as f:
                    downloader = MediaIoBaseDownload(f, request)
                    done = False
                    while not done:
                        status, done = downloader.next_chunk()
                        logger.info(f"⬇️ Downloading {full_path}: {int(status.progress() * 100)}%")
                downloaded_files.append(file_path)
                logger.info(f"✅ Downloaded: {full_path}")
            except Exception as e:
                logger.error(f"❌ Error downloading {full_path}: {e}")
                # Don't leave a truncated file in the download directory.
                try:
                    os.remove(file_path)
                except OSError:
                    pass
        return downloaded_files

    def download_all_files(self, folder_id: str, destination: str) -> List[str]:
        """Backward-compatible alias for download_all_files_recursive."""
        return self.download_all_files_recursive(folder_id, destination)
# ============================================================
# 5. ĐỌC FILE
# ============================================================
class DocumentReader:
    """Read answer-key files as text."""

    @staticmethod
    def read_file(file_path: str) -> str:
        """Return the decoded text of *file_path*, or "" if it cannot be read.

        The 'utf-8-sig' codec decodes plain UTF-8 exactly like 'utf-8'. It also
        drops a leading byte-order mark (common in files saved by Windows
        editors). Otherwise the BOM would survive as U+FEFF, which
        TextProcessor's whitespace cleanup does not remove, at the start of
        the first chunk.
        """
        try:
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                return f.read()
        except Exception as e:
            # Best-effort: log and let the caller skip this file.
            logger.error(f"❌ Error reading {file_path}: {e}")
            return ""
# ============================================================
# 6. EMBEDDING
# ============================================================
class EmbeddingModel:
    """Thin wrapper around a SentenceTransformer that turns text into vectors.

    NOTE(review): some Vietnamese embedding models expect word-segmented input;
    check the model card of the configured model and confirm.
    """

    def __init__(self, model_name: str = "dangvantuan/vietnamese-embedding"):
        """Load *model_name* and record its embedding dimension in ``vector_size``."""
        logger.info(f"🧠 Loading embedding model: {model_name}...")
        model = SentenceTransformer(model_name)
        self.model = model
        self.vector_size = model.get_sentence_embedding_dimension()
        logger.info(f"✅ Model loaded. Vector size: {self.vector_size}")

    def encode(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* (a single string is treated as a one-item list) as a NumPy array."""
        batch = [texts] if isinstance(texts, str) else texts
        return self.model.encode(batch, convert_to_numpy=True)
# ============================================================
# 7. QDRANT - ĐÃ SỬA LỖI
# ============================================================
class QdrantManager:
    """Manage the Qdrant collection that stores the answer-key chunks.

    By default the client runs in Qdrant's local mode, persisting under
    ``./qdrant_data`` (the previous hard-coded behaviour). Pass ``path=None``
    to connect to a Qdrant server at ``host:port`` instead.
    """

    def __init__(self, host: str = "localhost", port: int = 6333,
                 collection_name: str = "van_mau", vector_size: int = 384,
                 path: Optional[str] = "./qdrant_data"):
        """
        Args:
            host: Qdrant server host; used only when ``path`` is None.
            port: Qdrant server port; used only when ``path`` is None.
            collection_name: collection to create, fill and search.
            vector_size: embedding dimension used when creating the collection.
            path: directory for Qdrant local mode, or None to use host/port.
        """
        if path is not None:
            self.client = QdrantClient(path=path)
            logger.info(f"✅ Using local Qdrant storage at {path}")
        else:
            self.client = QdrantClient(host=host, port=port)
            logger.info(f"✅ Connected to Qdrant at {host}:{port}")
        self.collection_name = collection_name
        self.vector_size = vector_size

    def create_collection(self, force: bool = False):
        """Ensure the collection exists (cosine distance, ``vector_size`` dims).

        Args:
            force: if True and the collection already exists, delete it and
                recreate it empty.
        """
        collections = self.client.get_collections().collections
        exists = any(c.name == self.collection_name for c in collections)
        if exists:
            if not force:
                logger.info(f"Collection '{self.collection_name}' already exists.")
                return
            self.client.delete_collection(self.collection_name)
            logger.info(f"🗑️ Deleted existing collection: {self.collection_name}")
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.vector_size,
                distance=Distance.COSINE
            )
        )
        logger.info(f"✅ Collection '{self.collection_name}' created.")

    def upsert_chunks(self, chunks: List[Dict], embedding_model: "EmbeddingModel",
                      batch_size: int = 256):
        """Embed *chunks* and upsert them, sending at most *batch_size* points per request.

        Each chunk dict must have 'content' and 'document_id'; 'file_name' and
        'chunk_index' are optional.

        NOTE: point ids are the chunk positions 0..n-1. Upserting into an
        existing collection without recreating it overwrites those ids and
        leaves any higher ids from an earlier, larger run in place.

        Raises:
            ValueError: if batch_size is not positive.
        """
        if not chunks:
            logger.warning("No chunks to upsert.")
            return
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        embeddings = embedding_model.encode([chunk['content'] for chunk in chunks])
        points = [
            PointStruct(
                id=i,
                vector=embeddings[i].tolist(),
                payload={
                    "content": chunk['content'],
                    "document_id": chunk['document_id'],
                    "file_name": chunk.get('file_name', ''),
                    "chunk_index": chunk.get('chunk_index', i)
                }
            )
            for i, chunk in enumerate(chunks)
        ]
        # Batching keeps each request bounded in size when talking to a server.
        for offset in range(0, len(points), batch_size):
            self.client.upsert(
                collection_name=self.collection_name,
                points=points[offset:offset + batch_size]
            )
        logger.info(f"✅ Upserted {len(points)} chunks to Qdrant.")

    def search(self, query_vector: List[float], limit: int = 3) -> List[Dict]:
        """Return up to *limit* stored chunks most similar to *query_vector*.

        Uses ``query_points`` when the installed qdrant-client provides it and
        falls back to the older ``search`` method otherwise.

        Returns:
            Dicts with 'content', 'score', 'document_id' and 'file_name'.
        """
        if hasattr(self.client, "query_points"):
            response = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                limit=limit
            )
            hits = response.points
        else:
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=limit
            )
        return [self._hit_to_dict(hit) for hit in hits]

    @staticmethod
    def _hit_to_dict(hit) -> Dict:
        """Convert a scored point (anything with .payload and .score) into a plain dict."""
        payload = hit.payload or {}
        return {
            "content": payload.get("content", ""),
            "score": hit.score,
            "document_id": payload.get("document_id", ""),
            "file_name": payload.get("file_name", "")
        }
# ============================================================
# 8. HÀM search_context
# ============================================================
class ContextRetriever:
    """Fetch the reference-answer passages most similar to a student's essay."""

    def __init__(self, qdrant: QdrantManager, embedding: EmbeddingModel):
        self.qdrant, self.embedding = qdrant, embedding
        self.text_processor = TextProcessor()

    def search_context(self, bai_van_hoc_sinh: str, limit: int = 3) -> List[str]:
        """Return the text of the *limit* stored chunks closest to the essay.

        Args:
            bai_van_hoc_sinh: the student's essay (raw text). It is cleaned
                with TextProcessor.clean_text before embedding.
            limit: maximum number of passages to return.
        """
        query_text = self.text_processor.clean_text(bai_van_hoc_sinh)
        vector = self.embedding.encode([query_text])[0].tolist()
        hits = self.qdrant.search(query_vector=vector, limit=limit)
        contexts = []
        for hit in hits:
            contexts.append(hit['content'])
        logger.info(f"🔍 Found {len(contexts)} relevant context chunks.")
        return contexts
# ============================================================
# 9. PIPELINE CHÍNH
# ============================================================
class DataPipeline:
    """End-to-end ingestion: Google Drive -> clean -> chunk -> embed -> Qdrant."""

    # File suffixes read as plain text (matched case-insensitively).
    TEXT_EXTENSIONS = ('.txt', '.md', '.csv', '.json', '.html', '.xml')
    # Chunks shorter than this many characters are discarded as noise.
    MIN_CHUNK_CHARS = 10

    def __init__(self, config: Config):
        """Build every component from *config*.

        Construction already authenticates with Google Drive, loads the
        embedding model and opens the Qdrant client.
        """
        self.config = config
        self.text_processor = TextProcessor()
        self.chunking = ChunkingStrategy(
            chunk_size=config.CHUNK_SIZE,
            overlap=config.CHUNK_OVERLAP
        )
        self.gdrive = GoogleDriveManager(
            credentials_file=config.GOOGLE_CREDENTIALS_FILE
        )
        self.document_reader = DocumentReader()
        self.embedding_model = EmbeddingModel(config.EMBEDDING_MODEL_NAME)
        self.qdrant = QdrantManager(
            host=config.QDRANT_HOST,
            port=config.QDRANT_PORT,
            collection_name=config.QDRANT_COLLECTION_NAME,
            vector_size=self.embedding_model.vector_size
        )

    def run(self, force_reload: bool = False):
        """Run the whole pipeline.

        Args:
            force_reload: drop and recreate the Qdrant collection before upserting.

        Returns early, without touching Qdrant, when nothing was downloaded or
        no usable text chunk was produced (so an existing index is not wiped).
        """
        logger.info("=" * 60)
        logger.info(" DATA & VECTOR DB ENGINEER - BẮT ĐẦU")
        logger.info("=" * 60)

        # Step 1: download answer-key files from Google Drive (recursively).
        logger.info("\n⬇️ Nhiệm vụ 1: Tải file đáp án từ Google Drive")
        downloaded_files = self.gdrive.download_all_files_recursive(
            self.config.GOOGLE_FOLDER_ID,
            self.config.RAW_DIR
        )
        if not downloaded_files:
            logger.error("❌ Không có file nào được tải về.")
            logger.error(" Kiểm tra: 1) Folder ID đúng, 2) Service Account đã được share quyền.")
            return

        # Steps 2 & 3: clean and chunk.
        logger.info("\n Nhiệm vụ 2 & 3: Xử lý văn bản và Chunking")
        all_chunks = self._build_chunks(downloaded_files)
        if not all_chunks:
            logger.error("❌ No usable text chunks were produced; Qdrant collection left untouched.")
            return

        # Step 4: embed and push to Qdrant.
        logger.info("\n🧠 Nhiệm vụ 4: Embedding và đẩy lên Qdrant")
        self.qdrant.create_collection(force=force_reload)
        self.qdrant.upsert_chunks(all_chunks, self.embedding_model)

        logger.info("\n" + "=" * 60)
        logger.info(f"✅ DATA PIPELINE HOÀN TẤT!")
        logger.info(f"📊 Tổng số chunks đã xử lý: {len(all_chunks)}")
        logger.info("=" * 60)

    def _build_chunks(self, file_paths: List[str]) -> List[Dict]:
        """Read, clean and chunk each text file into dicts for QdrantManager.upsert_chunks."""
        all_chunks: List[Dict] = []
        for file_path in file_paths:
            if not file_path.lower().endswith(self.TEXT_EXTENSIONS):
                logger.info(f"⏭️ Bỏ qua file không phải text: {os.path.basename(file_path)}")
                continue
            raw_text = self.document_reader.read_file(file_path)
            if not raw_text:
                continue
            chunks_text = self.chunking.chunk_smart(self._clean_keeping_paragraphs(raw_text))
            file_name = os.path.basename(file_path)
            kept = 0
            for i, chunk_text in enumerate(chunks_text):
                if len(chunk_text.strip()) < self.MIN_CHUNK_CHARS:
                    continue
                all_chunks.append({
                    'content': chunk_text,
                    # Running counter: unique per chunk, not per source file.
                    'document_id': f"doc_{len(all_chunks)}",
                    'file_name': file_name,
                    'chunk_index': i
                })
                kept += 1
            logger.info(f"✅ Processed {file_name}: {kept} chunks")
        return all_chunks

    def _clean_keeping_paragraphs(self, raw_text: str) -> str:
        """Clean each blank-line-separated paragraph, then rejoin them with blank lines.

        clean_text collapses ALL whitespace, newlines included. Applying it to
        the whole file would erase the blank lines chunk_smart splits on, so
        every file would be treated as one giant paragraph.
        """
        paragraphs = (
            self.text_processor.clean_text(p)
            for p in re.split(r'\n\s*\n', raw_text)
        )
        return "\n\n".join(p for p in paragraphs if p)

    def get_retriever(self) -> ContextRetriever:
        """Return a ContextRetriever bound to this pipeline's Qdrant client and model."""
        return ContextRetriever(self.qdrant, self.embedding_model)
# ============================================================
# 10. MAIN
# ============================================================
def main():
    """Script entry point: print the configuration, run the pipeline, then smoke-test retrieval."""
    config = Config()
    banner = "=" * 60

    print("\n" + banner)
    print("👤 DATA & VECTOR DB ENGINEER - Người số 1")
    print(banner)
    summary_lines = (
        f"\n📁 Google Drive Folder ID: {config.GOOGLE_FOLDER_ID}",
        f"📄 Credentials: {config.GOOGLE_CREDENTIALS_FILE}",
        f"🗄️ Qdrant: {config.QDRANT_HOST}:{config.QDRANT_PORT}",
        f"📚 Collection: {config.QDRANT_COLLECTION_NAME}",
        f"🧠 Embedding Model: {config.EMBEDDING_MODEL_NAME}",
        f"✂️ Chunk size: {config.CHUNK_SIZE} (overlap: {config.CHUNK_OVERLAP})",
    )
    for line in summary_lines:
        print(line)
    print(banner + "\n")

    # Run the full pipeline, recreating the collection from scratch.
    pipeline = DataPipeline(config)
    pipeline.run(force_reload=True)

    # Quick retrieval smoke test.
    print("\n" + banner)
    print("🧪 TEST search_context")
    print(banner)
    retriever = pipeline.get_retriever()
    test_query = "Hãy phân tích nhân vật Chí Phèo"
    passages = retriever.search_context(test_query, limit=2)
    print(f"\n🔍 Query: '{test_query}'")
    print(f"📝 Found {len(passages)} results:\n")
    for rank, passage in enumerate(passages, start=1):
        print(f"{rank}. {passage[:300]}...\n")


if __name__ == "__main__":
    main()