bbkdevops's picture
download
raw
11.7 kB
"""
Phase 1C: Quality Filtering — คัดกรองให้บริสุทธิ์สูงสุด
Pipeline (run_pipeline() executes the steps in the order 1 → 3 → 4 → 2 → 5 → 6):
1. Format filter → ตัด empty, too short, too long
2. Dedup → ลบซ้ำด้วย MinHash (เทียบ signature กับรายการที่เก็บล่าสุด ไม่ได้ใช้ LSH banding)
3. Language check → ตรวจว่า Thai จริง / EN จริง
4. Quality score → ให้คะแนน Q+A (length ratio, info density)
5. Reward filter → ใช้โมเดลใน Ollama ตัดสิน "ดีหรือไม่"
6. Balance → สมดุล Thai:EN = 40:60
"""
import hashlib
import json
import re
from collections import deque
from pathlib import Path

import requests
from tqdm import tqdm
_HERE = Path(__file__).parent  # directory containing this script
RAW_DIR = _HERE / "raw"  # *.jsonl shards read by load_all_raw()
DISTILLED_DIR = _HERE / "distilled"  # more *.jsonl shards read by load_all_raw()
FILTERED_DIR = _HERE / "filtered"  # where run_pipeline() writes the cleaned dataset
FILTERED_DIR.mkdir(exist_ok=True)  # created at import time, before any step runs
OLLAMA_URL = "http://localhost:11434"  # Ollama server queried by the reward-model step
# ─── 1. Format Filter ─────────────────────────────────────────────────────────
def format_ok(
    item: dict,
    *,
    min_q_len: int = 5,
    min_a_len: int = 3,
    max_q_len: int = 2000,
    max_a_len: int = 4000,
) -> bool:
    """Return True if *item* holds a structurally usable question/answer pair.

    A record is rejected when it is not a dict; when ``question`` or
    ``answer`` is missing, not a string, or whitespace-only; when
    ``len(question)`` is outside ``[min_q_len, max_q_len]`` or ``len(answer)``
    is outside ``[min_a_len, max_a_len]``; or when the answer consists only
    of digits, whitespace and non-word characters.
    """
    if not isinstance(item, dict):
        return False
    q = item.get("question", "")
    a = item.get("answer", "")
    # Non-string values (e.g. a numeric answer decoded from JSON) would make
    # len()/re.fullmatch raise and abort the whole pipeline.
    if not isinstance(q, str) or not isinstance(a, str):
        return False
    # Whitespace-only text is effectively empty.
    if not q.strip() or not a.strip():
        return False
    if not (min_q_len <= len(q) <= max_q_len and min_a_len <= len(a) <= max_a_len):
        return False
    # Drop answers that are only digits / whitespace / symbols.
    return re.fullmatch(r"[\d\s\W]+", a) is None
# ─── 2. MinHash Deduplication ─────────────────────────────────────────────────
def shingles(text: str, k: int = 3) -> set[str]:
    """Return the set of character k-grams of *text* after lower-casing and stripping.

    Non-empty text shorter than *k* characters yields a single shingle holding
    the whole text, so very short strings still get a distinguishing signature.
    Empty or whitespace-only text yields an empty set.
    """
    text = text.lower().strip()
    if 0 < len(text) < k:
        return {text}
    return {text[i:i + k] for i in range(len(text) - k + 1)}


def minhash(text: str, num_hashes: int = 64) -> list[int]:
    """Return a MinHash signature of *text* with *num_hashes* entries.

    Entry ``i`` is the minimum, over the text's shingles, of the MD5 digest of
    ``f"{i}:{shingle}"`` read as a big-endian integer (identical to
    ``int(hexdigest, 16)``). Text with no shingles is treated as having the
    single shingle ``""``.
    """
    shingle_set = shingles(text)
    if not shingle_set:
        # Without a shingle the minimum stayed at float("inf") and int(inf)
        # raised OverflowError; hash one empty shingle instead so identical
        # empty texts get identical, well-defined signatures.
        shingle_set = {""}
    return [
        min(
            int.from_bytes(
                # MD5 is used only as a fast hash, not for security; the flag
                # keeps it usable on FIPS-restricted builds.
                hashlib.md5(f"{i}:{sh}".encode(), usedforsecurity=False).digest(),
                "big",
            )
            for sh in shingle_set
        )
        for i in range(num_hashes)
    ]
def jaccard_estimate(sig1: list[int], sig2: list[int]) -> float:
return sum(a == b for a, b in zip(sig1, sig2)) / len(sig1)
def deduplicate(
    items: list[dict],
    threshold: float = 0.85,
    *,
    window: int | None = 2000,
    num_hashes: int = 32,
) -> list[dict]:
    """Drop near-duplicate items using MinHash Jaccard similarity of the question.

    Each item's question signature is compared with the signatures of the
    most recently *kept* items: at most *window* of them, where ``None``
    means all kept items and *window* should otherwise be positive. An item
    whose estimated similarity to any of them is ``>= threshold`` is
    discarded. Surviving items keep their input order. Only the question
    text is compared; answers are ignored.
    """
    print(f" Deduplicating {len(items):,} items (threshold={threshold}) ...")
    # A bounded deque holds exactly the comparison window (the speed/recall
    # tradeoff): memory stays O(window) and no per-item slice copy is made.
    recent: deque[list[int]] = deque(maxlen=window)
    kept: list[dict] = []
    for item in tqdm(items):
        sig = minhash(item.get("question", ""), num_hashes=num_hashes)
        if any(jaccard_estimate(sig, prev) >= threshold for prev in recent):
            continue
        recent.append(sig)
        kept.append(item)
    print(f" After dedup: {len(kept):,} (removed {len(items)-len(kept):,})")
    return kept
# ─── 3. Language Verification ─────────────────────────────────────────────────
# Thai Unicode block U+0E00–U+0E7F. Written as escapes because U+0E00 is an
# unassigned, non-printing code point that is easily lost or mangled when the
# source file is copied or re-encoded.
THAI_CHARS = re.compile(r"[\u0e00-\u0e7f]")
EN_CHARS = re.compile(r"[a-zA-Z]")


def lang_ok(
    item: dict,
    *,
    min_th_ratio: float = 0.15,
    min_en_ratio: float = 0.4,
    max_th_ratio_en: float = 0.1,
) -> bool:
    """Check that the Q&A text really is in the language given by ``item["lang"]``.

    Ratios are character counts divided by the length of
    ``question + " " + answer`` after stripping; that length includes spaces,
    digits and punctuation.

    - ``"th"``: Thai-character ratio must exceed *min_th_ratio*. The default is
      lenient so that mixed Thai/English text is allowed.
    - ``"en"``: Latin-letter ratio must exceed *min_en_ratio* and Thai ratio
      must stay below *max_th_ratio_en*.
    - Any other or missing ``lang``: rejected.
    """
    lang = item.get("lang", "")
    # `or ""` also covers explicit None values, which previously raised TypeError.
    text = ((item.get("question") or "") + " " + (item.get("answer") or "")).strip()
    denom = max(len(text), 1)
    thai_ratio = len(THAI_CHARS.findall(text)) / denom
    en_ratio = len(EN_CHARS.findall(text)) / denom
    if lang == "th":
        return thai_ratio > min_th_ratio
    if lang == "en":
        return en_ratio > min_en_ratio and thai_ratio < max_th_ratio_en
    return False
# ─── 4. Quality Scoring ───────────────────────────────────────────────────────
# English question words must match as whole words ("show" is not "how",
# "whole" is not "who"). Thai is written without spaces between words, so
# Thai markers are matched as substrings.
_EN_QUESTION_WORD_RE = re.compile(r"\b(?:what|who|where|when|why|how|which)\b")
_TH_QUESTION_WORDS = (
    "อะไร", "ใคร", "ที่ไหน", "เมื่อไร", "ทำไม", "อย่างไร", "ไหน",
    "คือ", "คืออะไร", "หมายถึง",
)
# Phrases marking a non-answer / hedge in the answer text.
_BAD_ANSWER_PHRASES = ("i don't know", "i'm not sure", "as an ai", "ไม่ทราบ", "ไม่แน่ใจ")


def quality_score(item: dict) -> float:
    """Return a heuristic quality score in [0, 1] for a Q&A item.

    Five equally weighted checks. The score is the fraction that pass, so it
    is always one of 0.0, 0.2, 0.4, 0.6, 0.8 or 1.0:

    1. question length is 10–300 characters;
    2. answer length is 10–1000 characters;
    3. answer is longer than 20 characters (more than a bare yes/no);
    4. question contains an English question word (whole word) or a Thai
       question marker;
    5. answer contains none of the non-answer phrases.
    """
    q = item.get("question", "")
    a = item.get("answer", "")
    q_lower = q.lower()
    # Normalize typographic apostrophes so "I don’t know" matches "i don't know".
    a_norm = a.lower().replace("\u2019", "'")
    checks = (
        10 <= len(q) <= 300,
        10 <= len(a) <= 1000,
        len(a) > 20,
        bool(_EN_QUESTION_WORD_RE.search(q_lower))
        or any(w in q_lower for w in _TH_QUESTION_WORDS),
        not any(p in a_norm for p in _BAD_ANSWER_PHRASES),
    )
    # Divide the pass count instead of summing 0.2s, which drifts (0.6000000000000001).
    return sum(checks) / len(checks)
# ─── 5. Reward Model Filter (Ollama) ─────────────────────────────────────────
def get_reward_model() -> str | None:
    """Pick an installed Ollama model to act as the Q&A judge.

    Lists models via ``GET {OLLAMA_URL}/api/tags``. Families are tried in the
    fixed order llama3, mistral, gemma, qwen, phi; the first installed model
    whose lower-cased name contains the current family wins, in the server's
    listing order. Model size is not considered. If no family matches, the
    first listed model is returned.

    Returns None when no models are installed or when anything goes wrong
    (server unreachable, timeout, unexpected JSON).
    """
    families = ("llama3", "mistral", "gemma", "qwen", "phi")
    try:
        resp = requests.get(f"{OLLAMA_URL}/api/tags", timeout=5)
        names = [entry["name"] for entry in resp.json().get("models", [])]
        match = next(
            (name for family in families for name in names if family in name.lower()),
            None,
        )
        if match is not None:
            return match
        return names[0] if names else None
    except Exception:
        # Deliberate best effort: any failure means "no judge available".
        return None
def _build_judge_prompt(question: str, answer: str, lang: str) -> str:
    """Build the GOOD/BAD judging prompt: Thai wording when lang is "th", else English."""
    if lang == "th":
        return (
            'ประเมินคู่ Q&A นี้: ถ้าถูกต้องและมีประโยชน์ตอบ "GOOD" ถ้าไม่ดีตอบ "BAD"\n'
            f"คำถาม: {question}\n"
            f"คำตอบ: {answer}\n"
            "ตอบ GOOD หรือ BAD เท่านั้น:"
        )
    return (
        'Rate this Q&A pair. Reply "GOOD" if accurate and useful, "BAD" if not.\n'
        f"Q: {question}\n"
        f"A: {answer}\n"
        "Reply GOOD or BAD only:"
    )


def _is_good_verdict(response: str) -> bool:
    """Return True if the judge's reply is a GOOD verdict.

    The first standalone GOOD/BAD word decides. "NOT GOOD" counts as BAD, and
    a reply containing neither word counts as BAD.
    """
    words = re.findall(r"[A-Z]+", response.upper())
    for idx, word in enumerate(words):
        if word == "BAD":
            return False
        if word == "GOOD":
            return not (idx > 0 and words[idx - 1] == "NOT")
    return False


def reward_judge(
    items: list[dict],
    model: str,
    sample_size: int = 500,
    *,
    auto_keep_score: float = 0.8,
) -> list[dict]:
    """Filter items with an LLM judge served by Ollama.

    Items are ranked by ``_score``, highest first. The top *sample_size* items
    are sent to *model*, with the question cut to 200 chars and the answer to
    400, and kept only when the reply is a GOOD verdict. If a judge call fails
    (network error, HTTP error status, unparseable body), the item is kept,
    because the failure says nothing about its quality. Items beyond the
    sample are not judged; they are kept only if ``_score >= auto_keep_score``.

    Returns the judged-GOOD items followed by the auto-kept items.
    """
    if not items:
        return items
    ranked = sorted(items, key=lambda x: x.get("_score", 0), reverse=True)
    n_judged = min(sample_size, len(ranked))
    judged_good: list[dict] = []
    print(f" Reward judging {n_judged} samples with {model} ...")
    # One Session reuses the HTTP connection across all judge calls.
    with requests.Session() as session:
        for item in tqdm(ranked[:sample_size]):
            prompt = _build_judge_prompt(
                item["question"][:200], item["answer"][:400], item.get("lang", "en")
            )
            try:
                resp = session.post(
                    f"{OLLAMA_URL}/api/generate",
                    json={"model": model, "prompt": prompt, "stream": False,
                          "options": {"temperature": 0.1, "num_predict": 10}},
                    timeout=30,
                )
                # An HTTP error (e.g. unknown model) returns a JSON body without
                # "response"; without this check such items were silently
                # rejected instead of following the keep-on-error policy.
                resp.raise_for_status()
                good = _is_good_verdict(resp.json().get("response", ""))
            except Exception:
                good = True  # keep on error
            if good:
                judged_good.append(item)
    auto_keep = [x for x in ranked[sample_size:] if x.get("_score", 0) >= auto_keep_score]
    print(f" Reward filter: {len(judged_good)}/{n_judged} passed + {len(auto_keep)} auto-kept")
    return judged_good + auto_keep
# ─── 6. Balance Thai:EN ───────────────────────────────────────────────────────
def balance(items: list[dict], th_ratio: float = 0.4) -> list[dict]:
    """Cap each language's share of *items*, keeping the highest-scored items.

    With ``total = len(items)``, at most ``int(total * th_ratio)`` Thai items
    and ``total - int(total * th_ratio)`` English items are kept. Each group
    is chosen by descending ``_score``; ties keep input order. These are caps,
    not quotas: a shortfall in one language is not handed to the other, so
    the output can be smaller than the input and the requested ratio is not
    guaranteed. Items whose ``lang`` is neither "th" nor "en" are dropped.
    The result lists Thai items first, then English.

    Raises:
        ValueError: if th_ratio is not within [0, 1].
    """
    # Out-of-range ratios previously produced negative slice bounds and
    # silently dropped items from the wrong end of a group.
    if not 0.0 <= th_ratio <= 1.0:
        raise ValueError(f"th_ratio must be in [0, 1], got {th_ratio!r}")
    target_th = int(len(items) * th_ratio)
    target_en = len(items) - target_th

    def top(lang: str, limit: int) -> list[dict]:
        # Highest-scored `limit` items of one language (stable sort).
        pool = [x for x in items if x.get("lang") == lang]
        return sorted(pool, key=lambda x: x.get("_score", 0), reverse=True)[:limit]

    thai = top("th", target_th)
    en = top("en", target_en)
    print(f" Balanced: Thai {len(thai):,} | EN {len(en):,}")
    return thai + en
# ─── Main Pipeline ────────────────────────────────────────────────────────────
def load_all_raw() -> list[dict]:
items: list[dict] = []
for f in list(RAW_DIR.glob("*.jsonl")) + list(DISTILLED_DIR.glob("*.jsonl")):
with open(f, encoding="utf-8") as fp:
for line in fp:
line = line.strip()
if line:
try:
items.append(json.loads(line))
except Exception:
pass
print(f"Loaded {len(items):,} raw items from {RAW_DIR} + {DISTILLED_DIR}")
return items
def run_pipeline():
    """Run the whole quality-filter pipeline and write ``FILTERED_DIR/clean_qa.jsonl``.

    Steps, in execution order:

    1. format filter;
    2. language check;
    3. heuristic scoring, keeping ``_score >= 0.4``;
    4. MinHash dedup;
    5. LLM reward judge, or a stricter ``_score >= 0.6`` cut when no Ollama
       model is found;
    6. Thai/EN balancing.

    The internal ``_score`` field is removed before saving. The output is
    first written to a temporary sibling file and then renamed over the
    destination, so an interrupted run never leaves a truncated dataset
    behind.
    """
    print("=" * 60)
    print("TinyMind — Quality Filter Pipeline")
    print("=" * 60)
    items = load_all_raw()

    # 1. Format
    print("\n[1] Format filter ...")
    items = [x for x in items if format_ok(x)]
    print(f" After format: {len(items):,}")

    # 2. Language
    print("\n[2] Language verification ...")
    items = [x for x in items if lang_ok(x)]
    print(f" After lang check: {len(items):,}")

    # 3. Quality score
    print("\n[3] Quality scoring ...")
    for item in items:
        item["_score"] = quality_score(item)
    items = [x for x in items if x["_score"] >= 0.4]
    print(f" After score>=0.4: {len(items):,}")

    # 4. Dedup
    print("\n[4] Deduplication ...")
    items = deduplicate(items, threshold=0.85)

    # 5. Reward model
    print("\n[5] Reward model filter ...")
    judge_model = get_reward_model()
    if judge_model:
        print(f" Judge model: {judge_model}")
        items = reward_judge(items, judge_model, sample_size=1000)
    else:
        print(" Skipping (no Ollama) — using score threshold only")
        items = [x for x in items if x.get("_score", 0) >= 0.6]

    # 6. Balance
    print("\n[6] Balancing Thai:EN ...")
    items = balance(items, th_ratio=0.4)

    # Remove internal fields before saving
    for item in items:
        item.pop("_score", None)

    # Save atomically: write a temp file next to the target, then rename.
    out_path = FILTERED_DIR / "clean_qa.jsonl"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in items)
    # Path.replace -> os.replace: atomic on POSIX, and both paths are in the
    # same directory, hence on the same filesystem.
    tmp_path.replace(out_path)

    th = sum(1 for d in items if d.get("lang") == "th")
    en = sum(1 for d in items if d.get("lang") == "en")
    print(f"\n{'='*60}")
    print(f"FINAL DATASET: {len(items):,} pairs")
    print(f" Thai: {th:,} | EN: {en:,}")
    print(f" Saved → {out_path}")
    print("Dataset บริสุทธิ์พร้อมใช้ train!")
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, never on import.
    run_pipeline()

Xet Storage Details

Size:
11.7 kB
·
Xet hash:
2cb0aa8b37e1acf76c63ce6eed76e2ee58b0ab4497dcd34fc121f138e942b2b5

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.