# Origin: g8-cs106 / step3_encode_dataset.py
# Uploaded by gracephamit ("Upload 29 files", commit 0dd9600, verified).
# step3_encode_dataset_hybrid.py
import json
import os
import pickle
import re
import unicodedata

import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from underthesea import word_tokenize
# Sentence-transformer model loaded in main() to produce the dense vectors.
MODEL_NAME = "keepitreal/vietnamese-sbert"
# Input knowledge base, read with json.load in main().
INPUT_JSON = "dataset/knowledge_base.json"
# Output directory for every generated artifact (created by main() if missing).
OUT_DIR = "artifacts"
# Dense embedding matrix saved with np.save; one row per kept item.
VECTORS_NPY = os.path.join(OUT_DIR, "kb_vectors.npy")
# Per-item metadata list; entry i describes row i of the vector matrix.
META_JSON = os.path.join(OUT_DIR, "kb_meta.json")
# Pickled BM25Okapi index built over the tokenized corpus.
BM25_PKL = os.path.join(OUT_DIR, "bm25_index.pkl")
# Pickled list of token lists that the BM25 index was built from.
TOKENIZED_PKL = os.path.join(OUT_DIR, "tokenized_corpus.pkl")
def preprocess_vietnamese_text(text: str) -> str:
    """Normalize Vietnamese text for embedding and keyword indexing.

    Steps: Unicode NFC normalization, lowercasing, replacing every character
    that is neither a word character nor whitespace with a space, and
    collapsing runs of whitespace into single spaces.

    Args:
        text: Raw input text; ``None`` and ``""`` are accepted.

    Returns:
        The normalized text, or ``""`` for empty input.
    """
    if not text:
        return ""
    # Vietnamese text is frequently stored decomposed (NFD: base letter plus
    # combining tone marks). Combining marks are not matched by ``\w``, so
    # without NFC the filter below would split accented words apart
    # ("tiếng" -> "tie ng"). NFC turns them into precomposed letters.
    text = unicodedata.normalize("NFC", text).lower()
    # With a str pattern, ``\w`` is Unicode-aware and already covers every
    # precomposed Vietnamese letter (à, ế, ự, đ, ...), so diacritics survive.
    text = re.sub(r"[^\w\s]", " ", text)
    return " ".join(text.split())
def extract_keywords(item: dict, max_keywords: int = 30) -> list:
    """Collect keyword phrases for one knowledge-base item.

    The lowercased topic and chapter come first, each kept whole as a single
    phrase. They are followed by consecutive word pairs (bigrams) from the
    lowercased ``content_for_embedding``, in reading order. The content is
    only lowercased, not punctuation-stripped, so a bigram may carry attached
    punctuation (e.g. "hàm, là").

    Args:
        item: One knowledge-base record (a dict decoded from JSON). Missing
            or ``None`` fields are treated as empty.
        max_keywords: Maximum number of phrases returned. The default of 30
            is the previously hard-coded limit.

    Returns:
        A list of at most ``max_keywords`` lowercase phrases; may be empty.
    """
    def _field(key: str) -> str:
        # JSON null arrives as None; str() also tolerates numeric values.
        value = item.get(key)
        return "" if value is None else str(value).strip().lower()

    # Topic/chapter are kept whole so they can match as exact phrases;
    # stripping drops whitespace-only values instead of keeping "  ".
    keywords = [phrase for phrase in (_field("topic"), _field("chapter")) if phrase]

    words = _field("content_for_embedding").split()
    # Stop as soon as the cap is reached rather than building every bigram
    # of a long document only to discard most of them.
    for first, second in zip(words, words[1:]):
        if len(keywords) >= max_keywords:
            break
        keywords.append(f"{first} {second}")
    return keywords[:max_keywords]
def extract_text_for_embedding(item: dict) -> str:
    """Build the normalized text that represents one knowledge-base item.

    Combines, in order:
    - "Chủ đề: <topic>";
    - the item's ``content_for_embedding``;
    - ``metadata["raw_text"]``, only when it differs from the content;
    - "Thuộc: <chapter>".

    The non-empty parts are joined with ". " and passed through
    ``preprocess_vietnamese_text``, which lowercases them and replaces
    punctuation with spaces.

    Missing fields, ``None`` (JSON null) and non-string values are tolerated
    rather than raising ``AttributeError``. A non-dict ``metadata`` is
    ignored.

    Args:
        item: One knowledge-base record (a dict decoded from JSON).

    Returns:
        The preprocessed text, or ``""`` when the item has no usable fields.
    """
    def _field(source: dict, key: str) -> str:
        # JSON null arrives as None; str() also tolerates numeric values.
        value = source.get(key)
        return "" if value is None else str(value).strip()

    parts = []

    topic = _field(item, "topic")
    if topic:
        parts.append(f"Chủ đề: {topic}")

    content = _field(item, "content_for_embedding")
    if content:
        parts.append(content)

    metadata = item.get("metadata", {})
    if isinstance(metadata, dict):
        raw_text = _field(metadata, "raw_text")
        # Skip raw_text when it merely duplicates the content.
        if raw_text and raw_text != content:
            parts.append(raw_text)

    chapter = _field(item, "chapter")
    if chapter:
        parts.append(f"Thuộc: {chapter}")

    return preprocess_vietnamese_text(". ".join(parts))
def _collect_corpus(data: list) -> tuple:
    """Turn raw knowledge-base items into parallel (texts, meta, all_keywords) lists.

    An item is skipped, with a warning, when its embedding text is empty or
    shorter than 10 characters. Position ``i`` in all three lists therefore
    refers to the same kept item, and ``meta[i]["index"] == i``.
    """
    texts = []
    meta = []
    all_keywords = []
    for idx, item in enumerate(data):
        item_id = item.get("id", f"idx_{idx}")
        text = extract_text_for_embedding(item)
        if not text or len(text) < 10:
            print(f"⚠️ Warning: Item {item_id} has insufficient text")
            continue
        # Keywords are only computed for items that are actually kept.
        keywords = extract_keywords(item)
        # extract_text_for_embedding ignores a missing/non-dict "metadata";
        # mirror that here instead of crashing on e.g. "metadata": null.
        metadata = item.get("metadata", {})
        knowledge_type = metadata.get("knowledge_type", "") if isinstance(metadata, dict) else ""

        texts.append(text)
        all_keywords.append(keywords)
        meta.append({
            "index": len(texts) - 1,
            "id": item_id,
            "topic": item.get("topic", ""),
            "chapter": item.get("chapter", ""),
            "knowledge_type": knowledge_type,
            "keywords": keywords,
            "text_length": len(text),
        })
    return texts, meta, all_keywords


def _tokenize_corpus(texts: list) -> list:
    """Word-segment each text for BM25, falling back to a whitespace split per text."""
    tokenized = []
    fallbacks = 0
    for text in texts:
        try:
            # format="text" returns one string in which underthesea joins
            # multi-syllable words with "_"; split() yields the tokens.
            tokens = word_tokenize(text, format="text").split()
        except Exception:
            # Best effort: keep the document indexable if the tokenizer
            # raises. Catching Exception (not a bare except) lets Ctrl-C /
            # SystemExit propagate.
            tokens = text.split()
            fallbacks += 1
        tokenized.append(tokens)
    if fallbacks:
        print(f"⚠️ Warning: {fallbacks} text(s) fell back to whitespace tokenization")
    return tokenized


def main():
    """Build the hybrid (dense + BM25) retrieval artifacts for the knowledge base.

    The steps are:
    1. Read INPUT_JSON and keep the items that have enough text.
    2. Encode the kept texts with the MODEL_NAME sentence-transformer into
       L2-normalized float32 vectors.
    3. Build a BM25Okapi index over the underthesea word tokens.
    4. Write VECTORS_NPY, META_JSON, BM25_PKL and TOKENIZED_PKL into OUT_DIR.

    Raises:
        SystemExit: if no item has usable text, so there is nothing to index.
    """
    os.makedirs(OUT_DIR, exist_ok=True)

    # Load dataset
    with open(INPUT_JSON, "r", encoding="utf-8") as f:
        data = json.load(f)
    print(f"📊 Processing {len(data)} items...")

    texts, meta, all_keywords = _collect_corpus(data)
    if not texts:
        # Without this guard, np.mean([]) would report nan, and rank_bm25
        # cannot build an index from an empty corpus (it divides by the
        # corpus size).
        raise SystemExit(f"❌ No items with usable text in {INPUT_JSON}; nothing to index.")

    print(f"📏 Avg text length: {np.mean([m['text_length'] for m in meta]):.0f} chars")
    print(f"🔑 Avg keywords: {np.mean([len(k) for k in all_keywords]):.1f} per item")

    # ===== 1. Semantic Embeddings =====
    print(f"\n🤖 Loading model: {MODEL_NAME}")
    model = SentenceTransformer(MODEL_NAME)
    print("🔄 Encoding semantic vectors...")
    vectors = model.encode(
        texts,
        batch_size=32,
        show_progress_bar=True,
        # Unit-length vectors, so a plain dot product equals cosine similarity.
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
    vectors = np.asarray(vectors, dtype=np.float32)

    # ===== 2. BM25 Index =====
    print("\n📝 Building BM25 index...")
    tokenized_corpus = _tokenize_corpus(texts)
    bm25 = BM25Okapi(tokenized_corpus)

    # ===== 3. Save Everything =====
    print("\n💾 Saving artifacts...")
    np.save(VECTORS_NPY, vectors)
    with open(META_JSON, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    # NOTE: these two files are pickles. Unpickling executes code, so only
    # load them from a trusted artifacts directory.
    with open(BM25_PKL, "wb") as f:
        pickle.dump(bm25, f)
    with open(TOKENIZED_PKL, "wb") as f:
        pickle.dump(tokenized_corpus, f)

    print("\n✅ Step 3 DONE (Hybrid)")
    print(f"📦 Items: {len(texts)}")
    print(f"📐 Vector shape: {vectors.shape}")
    print("💾 Saved:")
    for path in (VECTORS_NPY, META_JSON, BM25_PKL, TOKENIZED_PKL):
        print(f" - {path}")
# Run the encoding pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()