# kasitbot / embedding_generator.py
# Commit 71e1c4b — "Deploy KASITBot RAG chatbot" (author: snygginghani)
"""
================================================================================
Phase 3 — Embedding Generation — v3
University-Level RAG Pipeline — Vector Embedding via OpenAI API
================================================================================
KEY CHANGES vs v2:
✅ Embedding model is text-embedding-3-large everywhere (was inconsistent).
✅ Accepts bilingual text (Arabic + English) without translation;
   OpenAI's model handles both languages natively.
✅ Chunk sizing is done upstream by rag_preprocessor.py (400-char chunks); this
   script does not re-chunk, it only truncates a chunk that exceeds the
   8,000-token embedding input limit (and warns about it).
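Requirements:
    pip install openai python-dotenv
    pip install tiktoken   # optional: exact token-based truncation
                           # (without it a conservative character limit is used)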
Usage:
python embedding_generator.py
================================================================================
"""
import json
import math
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from openai import OpenAI
# ── Token-safe truncation (Arabic BPE can use 2-3 tokens per character) ───────
_MAX_TOKENS = 8_000 # hard limit is 8192; give a 192-token safety margin
try:
import tiktoken
_enc = tiktoken.get_encoding("cl100k_base")
def _truncate_to_token_limit(text: str) -> tuple:
"""Returns (truncated_text, was_truncated)."""
tokens = _enc.encode(text)
if len(tokens) <= _MAX_TOKENS:
return text, False
return _enc.decode(tokens[:_MAX_TOKENS]), True
except ImportError:
# Fallback: 8000 tokens Γ· 3 tokens-per-char (Arabic worst case) β‰ˆ 2666 chars
_CHAR_LIMIT = _MAX_TOKENS // 3
def _truncate_to_token_limit(text: str) -> tuple:
if len(text) <= _CHAR_LIMIT:
return text, False
return text[:_CHAR_LIMIT], True
# ── Configuration ─────────────────────────────────────────────────────────────
# One embedding model for the whole system: it must be the same one that
# vector_store and the app use for queries.
EMBEDDING_MODEL = "text-embedding-3-large"
BATCH_SIZE = 100  # inputs sent per embeddings request
OUTPUT_FILE = Path("rag_dataset_with_embeddings.json")
INPUT_FILES: List[Path] = [Path("rag_dataset.json")]
# Every *.json file in this directory is loaded automatically as Q&A pairs.
QA_DIR = Path("QandA")
# ══════════════════════════════════════════════════════════════════════════════
# Schema Normalisation
# ══════════════════════════════════════════════════════════════════════════════
def _first(record: Dict, *keys: str, default: Any = None) -> Any:
for k in keys:
if k in record:
return record[k]
return default
# Accepted aliases for each canonical field, in lookup-priority order. Their
# union is excluded from the pass-through "extras" so the two cannot drift.
_TEXT_KEYS = ("text", "content", "body", "chunk", "passage")
_SOURCE_KEYS = ("source", "file", "filename", "document", "doc_name", "origin")
_CHUNK_ID_KEYS = ("chunk_id", "chunk_index", "id", "index", "idx")
_LANGUAGE_KEYS = ("language", "lang", "locale")
_TRANSLATED_KEYS = ("was_translated", "translated", "is_translated")
_CANONICAL_KEYS = frozenset(
    _TEXT_KEYS + _SOURCE_KEYS + _CHUNK_ID_KEYS + _LANGUAGE_KEYS + _TRANSLATED_KEYS
)


def normalize_record(raw: Dict[str, Any], source_file: Path, index: int) -> Optional[Dict[str, Any]]:
    """Map one raw input record onto the pipeline's canonical schema.

    Args:
        raw: one record from an input file; each canonical field may be stored
            under any of the aliases in the ``_*_KEYS`` tuples above.
        source_file: file the record came from; its name is the fallback source.
        index: fallback ``chunk_id`` when the record carries none.

    Returns:
        A dict with ``text`` (stripped), ``source``, ``chunk_id``, ``language``
        and ``was_translated``, plus every non-alias key copied through
        unchanged — or None when the record has no non-blank text.
    """
    text = _first(raw, *_TEXT_KEYS)
    if not text or not str(text).strip():
        return None
    source = _first(raw, *_SOURCE_KEYS, default=source_file.name)
    chunk_id = _first(raw, *_CHUNK_ID_KEYS, default=index)
    language = _first(raw, *_LANGUAGE_KEYS, default="unknown")
    # NOTE(review): bool() makes a string such as "false" truthy — confirm the
    # inputs store this flag as a JSON boolean.
    was_translated = _first(raw, *_TRANSLATED_KEYS, default=False)
    extras = {k: v for k, v in raw.items() if k not in _CANONICAL_KEYS}

    # isdecimal(), not isdigit(): isdigit() also accepts characters such as "²"
    # that int() cannot parse, which would crash here.
    chunk_id_str = str(chunk_id)
    return {
        "text": str(text).strip(),
        "source": str(source),
        "chunk_id": int(chunk_id_str) if chunk_id_str.isdecimal() else chunk_id,
        "language": str(language),
        "was_translated": bool(was_translated),
        **extras,
    }
# ══════════════════════════════════════════════════════════════════════════════
# Q&A File Loader
# Handles the {metadata, qa_pairs: [{question, answer}]} format used in QandA/
# ══════════════════════════════════════════════════════════════════════════════
# Filename-to-doc_type mapping (mirrors rag_preprocessor.py patterns)
_QA_DOC_TYPE_PATTERNS = [
("exam_schedule", ["exam", "mid_exam", "schedual"]),
("office_hours", ["office_hours", "office hours"]),
("academic_calendar", ["calendar"]),
("study_plan", ["study_plan", "study plan"]),
("admissions_fees", ["admission", "fee"]),
("scholarship", ["makruma", "grant", "scholarship"]),
("regulation", ["regulation", "ΨͺΨΉΩ„ΩŠΩ…Ψ§Ψͺ", "Ω‚Ψ§Ω†ΩˆΩ†"]),
("course_records", ["course_record", "grade"]),
("departments", ["department", "major"]),
("faculty_info", ["faculty"]),
("careers", ["career"]),
]
def _qa_detect_doc_type(filename: str) -> str:
    """Derive a doc_type label for a Q&A file from its name.

    The lower-cased filename is searched for each entry's substrings in
    ``_QA_DOC_TYPE_PATTERNS`` order; the first entry with a hit decides the
    type, and "qa_pair" is returned when nothing matches.
    """
    lowered = filename.lower()
    matches = (
        dtype
        for dtype, needles in _QA_DOC_TYPE_PATTERNS
        if any(needle in lowered for needle in needles)
    )
    return next(matches, "qa_pair")
def _lang(text: str) -> str:
"""Fast language label from character ratios."""
ar = sum(1 for c in text if "Ψ€" <= c <= "ΫΏ")
en = sum(1 for c in text if c.isalpha() and c.isascii())
total = ar + en
if total == 0:
return "Unknown"
ratio = ar / total
if ratio > 0.6:
return "Arabic"
if ratio < 0.1:
return "English"
return "Mixed"
def load_qa_files(qa_dir: Path) -> List[Dict[str, Any]]:
    """
    Load every *.json file in qa_dir that follows the Q&A format:
        { "metadata": { "title": "...", ... },
          "qa_pairs": [ { "question": "...", "answer": "..." }, ... ] }
    Each pair becomes one record with:
        text          = "Question: {q}\\nAnswer: {a}"
        source        = the json filename
        chunk_id      = 1-based position of the pair within its file
        doc_type      derived from the filename
        section_title from metadata.title (falls back to the file stem)
    This lets the LLM quote the answer verbatim when a Q&A chunk is retrieved.

    Unreadable files and files without a "qa_pairs" list are reported and
    skipped; pairs that are not objects or lack a question/answer are skipped.
    Returns an empty list when qa_dir is not a directory.
    """
    if not qa_dir.is_dir():
        print(f" ⚠ QandA directory '{qa_dir}' not found — skipping Q&A files.")
        return []

    all_records: List[Dict[str, Any]] = []
    for path in sorted(qa_dir.glob("*.json")):
        try:
            with open(path, "r", encoding="utf-8") as fh:
                data = json.load(fh)
        except (OSError, ValueError) as exc:  # ValueError covers JSON and UTF-8 decode errors
            print(f" ⚠ Cannot read '{path.name}': {exc}")
            continue

        qa_pairs = data.get("qa_pairs") if isinstance(data, dict) else None
        if not isinstance(qa_pairs, list):
            print(f" ⚠ '{path.name}' has no 'qa_pairs' list — skipping.")
            continue

        meta = data.get("metadata")
        if not isinstance(meta, dict):  # missing or null metadata
            meta = {}
        section_title = meta.get("title", path.stem.replace("_", " ").title())
        doc_type = _qa_detect_doc_type(path.name)

        count = 0
        for idx, pair in enumerate(qa_pairs, start=1):
            if not isinstance(pair, dict):
                continue
            # str() guards against numeric answers, which have no .strip().
            q = str(pair.get("question") or "").strip()
            a = str(pair.get("answer") or "").strip()
            if not q or not a:
                continue
            text = f"Question: {q}\nAnswer: {a}"
            all_records.append({
                "text": text,
                "source": path.name,
                "chunk_id": idx,
                "language": _lang(text),
                "was_translated": False,
                "doc_type": doc_type,
                "section_title": section_title,
            })
            count += 1
        print(f" ✓ '{path.name}' → {count} Q&A pairs (doc_type={doc_type})")
    return all_records
# ══════════════════════════════════════════════════════════════════════════════
# Load & Merge Input Files
# ══════════════════════════════════════════════════════════════════════════════
def load_and_merge(paths: List[Path]) -> List[Dict[str, Any]]:
    """Load, normalise and concatenate the records of every file in *paths*.

    A file may hold a JSON list of records, or a JSON object whose "records" or
    "data" key — or, failing both, whose values — hold them. Missing files are
    skipped; entries that are not objects or have no text are skipped and
    counted.

    Raises:
        ValueError: if no record at all could be loaded (a malformed JSON file
            raises json.JSONDecodeError, also a ValueError).
    """
    merged: List[Dict[str, Any]] = []
    skipped_files = skipped_records = 0
    for path in paths:
        if not path.exists():
            print(f" ⚠ '{path}' not found — skipping.")
            skipped_files += 1
            continue
        with open(path, "r", encoding="utf-8") as fh:
            raw_data = json.load(fh)
        if isinstance(raw_data, dict):
            raw_data = raw_data.get("records") or raw_data.get("data") or list(raw_data.values())

        file_ok = 0
        for raw_rec in raw_data:
            # Fallback chunk_id = the position this record will take in the
            # merged list, so generated ids are unique across all input files.
            normalised = (
                normalize_record(raw_rec, path, index=len(merged))
                if isinstance(raw_rec, dict)
                else None
            )
            if normalised is None:
                skipped_records += 1
            else:
                merged.append(normalised)
                file_ok += 1
        print(f" ✓ '{path}' → {file_ok} records loaded.")

    print(f"\n Total merged : {len(merged)} records")
    if skipped_files:
        print(f" ⚠ Skipped files : {skipped_files}")
    if skipped_records:
        print(f" ⚠ Skipped records : {skipped_records}")
    if not merged:
        raise ValueError("[ERROR] No records loaded. Check INPUT_FILES.")
    return merged
# ══════════════════════════════════════════════════════════════════════════════
# OpenAI Client
# ══════════════════════════════════════════════════════════════════════════════
def load_client() -> OpenAI:
    """Create the OpenAI client from the OPENAI_API_KEY environment variable.

    Raises:
        RuntimeError: if OPENAI_API_KEY is unset or empty, instead of handing
            an empty key to the client and failing later on the first batch.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError(
            "OPENAI_API_KEY is not set — export it before running this script."
        )
    client = OpenAI(api_key=api_key)
    print(f" ✓ OpenAI client ready (model: {EMBEDDING_MODEL})")
    return client
# ══════════════════════════════════════════════════════════════════════════════
# Embedding Generation
# ══════════════════════════════════════════════════════════════════════════════
def normalize_l2(vec: List[float]) -> List[float]:
    """Scale *vec* to unit Euclidean length and return it as a new list.

    A zero-magnitude vector (including an empty one) cannot be scaled and is
    handed back unchanged — the same list object, not a copy.
    """
    squared_sum = sum(component * component for component in vec)
    magnitude = math.sqrt(squared_sum)
    return vec if magnitude == 0 else [component / magnitude for component in vec]
_MAX_ATTEMPTS = 3          # API calls per batch before giving up
_RETRY_BASE_DELAY_S = 5.0  # first back-off; doubled after each failed attempt


def _embed_batch_with_retry(client: OpenAI, batch: List[str]) -> List[List[float]]:
    """Embed *batch* with one API call and return the raw vectors in input order.

    Only transient failures (connection problems, rate limits, 5xx responses)
    are retried, with exponential back-off; any other error (e.g. a rejected
    API key or an invalid request) propagates immediately.

    Raises:
        RuntimeError: if the API returns a different number of vectors than
            inputs were sent.
    """
    # Local import: only this helper needs the exception classes.
    from openai import APIConnectionError, InternalServerError, RateLimitError

    transient_errors = (APIConnectionError, InternalServerError, RateLimitError)
    attempt = 1
    while True:
        try:
            response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
            break
        except transient_errors as exc:
            if attempt >= _MAX_ATTEMPTS:
                raise
            delay = _RETRY_BASE_DELAY_S * 2 ** (attempt - 1)
            print(f"\n ⚠ API error (attempt {attempt}/{_MAX_ATTEMPTS}): {exc} "
                  f"— retrying in {delay:.0f}s...")
            time.sleep(delay)
            attempt += 1

    # Each returned item carries the index of the input it embeds: order by it
    # rather than trusting positional order, and refuse a short response rather
    # than attaching vectors to the wrong records.
    items = sorted(response.data, key=lambda item: item.index)
    if len(items) != len(batch):
        raise RuntimeError(
            f"Embedding API returned {len(items)} vectors for {len(batch)} inputs"
        )
    return [item.embedding for item in items]


def generate_embeddings(records: List[Dict[str, Any]], client: OpenAI, batch_size: int = BATCH_SIZE) -> List[Dict[str, Any]]:
    """
    Embed all text chunks with text-embedding-3-large.
    Handles Arabic and English natively — no translation needed.
    Vectors are L2-normalised for cosine similarity via dot product.

    text-embedding-3-large has an 8192-token hard limit. Arabic BPE tokens
    can be 2-3 tokens per character, so an 18 000-char Arabic passage can be
    ~36 000 tokens. We use tiktoken (or a conservative char fallback) to
    truncate to 8 000 tokens before sending each item to the API.

    Returns:
        Shallow copies of *records*, in the same order, each with an added
        "embedding" key holding the normalised vector.
    """
    truncated = 0
    texts = [r["text"] for r in records]  # fail fast on a record without text
    total = len(texts)
    results: List[Dict[str, Any]] = []
    start_time = time.time()
    print(f"\n Embedding {total} chunks (Arabic + English natively) ...")

    # NOTE(review): besides the per-input limit, the embeddings endpoint also
    # caps the total tokens summed over all inputs of one request; batch_size
    # inputs that are each near 8 000 tokens could exceed it — consider
    # token-budgeted batching if that error shows up.
    for start in range(0, total, batch_size):
        batch_records = records[start:start + batch_size]
        batch_texts: List[str] = []
        for offset, text in enumerate(texts[start:start + batch_size]):
            safe, was_cut = _truncate_to_token_limit(text)
            if was_cut:
                src = batch_records[offset].get("source", "?")
                print(f"\n ⚠ Chunk {start + offset} truncated ({len(text)} chars, >{_MAX_TOKENS} tokens) "
                      f"— source: {src}")
                truncated += 1
            batch_texts.append(safe)

        vectors = _embed_batch_with_retry(client, batch_texts)
        for record, vec in zip(batch_records, vectors):
            enriched = record.copy()
            enriched["embedding"] = normalize_l2(vec)
            results.append(enriched)

        done = start + len(batch_records)
        elapsed = time.time() - start_time
        print(f" [{done}/{total}] {done / total * 100:.0f}% ({elapsed:.1f}s)", end="\r")

    print(f"\n ✓ All embeddings generated in {time.time() - start_time:.1f}s.")
    if truncated:
        print(f" ⚠ {truncated} chunk(s) truncated to fit the 8192-token limit.")
        print(" Re-run rag_preprocessor.py first to generate properly-sized chunks.")
    return results
# ══════════════════════════════════════════════════════════════════════════════
# Save Output
# ══════════════════════════════════════════════════════════════════════════════
def save_dataset(records: List[Dict[str, Any]], path: Path) -> None:
    """Write *records* to *path* as indented UTF-8 JSON (Arabic kept unescaped).

    The JSON is written to a sibling ``<name>.tmp`` file that then replaces
    *path* in one step, so a failure during serialisation leaves any previous
    output intact instead of a half-written file. The temp file is removed on
    failure.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(records, fh, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)
    finally:
        # Only still present if something above failed.
        if tmp_path.exists():
            tmp_path.unlink()
    print(f" ✓ Saved to '{path}' ({path.stat().st_size / 1_048_576:.1f} MB)")
# ══════════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════════
def main() -> None:
    """Run the pipeline: load RAG + Q&A records, embed them, save, summarise."""
    from collections import Counter

    print("=" * 70)
    print(" Phase 3 — Embedding Generation v4 (RAG + Q&A merge)")
    print(f" Model: {EMBEDDING_MODEL}")
    print("=" * 70)

    print("\n[STEP 1] Loading RAG dataset ...")
    records = load_and_merge(INPUT_FILES)
    rag_count = len(records)

    print(f"\n[STEP 2] Loading Q&A files from '{QA_DIR}/' ...")
    qa_records = load_qa_files(QA_DIR)
    if qa_records:
        # NOTE(review): chunk_ids are NOT offset here — Q&A ids restart at 1 in
        # every file, so records are only unique by (source, chunk_id). Confirm
        # the vector store keys on both fields.
        records.extend(qa_records)
        print(f" Total after merge : {len(records)} records "
              f"({rag_count} RAG + {len(qa_records)} Q&A)")
    else:
        print(" No Q&A records found — continuing with RAG dataset only.")

    print("\n[STEP 3] Initialising OpenAI client ...")
    client = load_client()

    print("\n[STEP 4] Generating embeddings ...")
    enriched = generate_embeddings(records, client)

    print("\n[STEP 5] Saving output ...")
    save_dataset(enriched, OUTPUT_FILE)

    dims = len(enriched[0]["embedding"]) if enriched else 0
    ar_count = sum(1 for r in enriched if r.get("language") == "Arabic")
    en_count = sum(1 for r in enriched if r.get("language") == "English")
    # Records without a doc_type key are grouped under "general".
    dtypes = Counter(r.get("doc_type", "general") for r in enriched)

    print("\n" + "=" * 70)
    print(" Done!")
    print(f" Total embedded : {len(enriched)}")
    print(f" Q&A chunks : {len(qa_records)}")
    print(f" Arabic chunks : {ar_count}")
    print(f" English chunks : {en_count}")
    print(f" Embedding dims : {dims}")
    print(f" Output file : {OUTPUT_FILE}")
    print("\n By document type:")
    for dt, cnt in dtypes.most_common():
        print(f" {dt:<22}: {cnt:>4}")
    print("=" * 70)


if __name__ == "__main__":
    main()