# Source: qwen-aec-reader / serving/chain_cache.py
# Uploaded by riosst with huggingface_hub ("Upload serving/chain_cache.py with huggingface_hub"),
# revision d2092e5 (verified), 12.2 kB.
"""
AEC AI Reader — Chain Cache (SQLite)

Caches response chains using semantic similarity.

Strategy:
- Pre-seed with the entire training dataset at startup
- Runtime: embed the instruction → cosine search → hit? return the cached response : run LLM inference → cache the result
- Cache hit = <200 ms, vs. LLM inference = 15-40 seconds
"""
import sqlite3
import json
import numpy as np
from pathlib import Path
from typing import Optional, Dict, List, Tuple
import time
import hashlib
class ChainCache:
    """SQLite-backed semantic cache for AEC tool chains.

    Each entry stores an instruction, its embedding (raw float32 bytes) and
    the response that was produced for it. ``lookup`` first tries an exact
    match on a hash of the normalised instruction and then falls back to a
    cosine-similarity scan over all stored embeddings.

    Every method opens its own short-lived SQLite connection and always
    closes it, including when an error is raised part-way through.
    """

    def __init__(
        self,
        db_path: str = "chain_cache.sqlite",
        similarity_threshold: float = 0.92,
        embedding_model_name: str = "intfloat/multilingual-e5-small"
    ):
        """Create the cache and make sure the database schema exists.

        Args:
            db_path: Path of the SQLite database file.
            similarity_threshold: Minimum cosine similarity (inclusive) for a
                semantic lookup to count as a hit.
            embedding_model_name: Name passed to ``SentenceTransformer`` when
                the embedding model is first needed.
        """
        self.db_path = db_path
        self.similarity_threshold = similarity_threshold
        self.embedding_model_name = embedding_model_name
        # Loaded lazily by the ``embedding_model`` property.
        self._embedding_model = None
        self._init_db()

    def _init_db(self):
        """Create the ``chain_cache`` table and its indexes if missing."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.executescript("""
                CREATE TABLE IF NOT EXISTS chain_cache (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    instruction_hash TEXT UNIQUE,
                    instruction_text TEXT NOT NULL,
                    instruction_embedding BLOB NOT NULL,
                    thinking TEXT,
                    output_json TEXT NOT NULL,
                    output_type TEXT NOT NULL,
                    hit_count INTEGER DEFAULT 0,
                    created_at REAL NOT NULL,
                    last_hit REAL
                );
                CREATE INDEX IF NOT EXISTS idx_output_type
                    ON chain_cache(output_type);
                CREATE INDEX IF NOT EXISTS idx_hit_count
                    ON chain_cache(hit_count DESC);
                CREATE INDEX IF NOT EXISTS idx_instruction_hash
                    ON chain_cache(instruction_hash);
            """)
            conn.commit()
        finally:
            conn.close()

    @property
    def embedding_model(self):
        """Return the embedding model, loading it on first access.

        Raises:
            ImportError: If ``sentence-transformers`` is not installed.
        """
        if self._embedding_model is None:
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError as exc:
                raise ImportError(
                    "Install sentence-transformers: pip install sentence-transformers"
                ) from exc
            self._embedding_model = SentenceTransformer(self.embedding_model_name)
            print(f"[Cache] Loaded embedding model: {self.embedding_model_name}")
        return self._embedding_model

    def _embed(self, text: str) -> np.ndarray:
        """Embed ``text`` and return the vector as float32.

        The model is asked for normalised embeddings, so a plain dot product
        between two results is used as their cosine similarity.
        """
        # Model names containing "e5" get the "query: " prefix.
        if "e5" in self.embedding_model_name.lower():
            text = f"query: {text}"
        embedding = self.embedding_model.encode(text, normalize_embeddings=True)
        return np.asarray(embedding, dtype=np.float32)

    def _hash_instruction(self, text: str) -> str:
        """Return the exact-match key for ``text``.

        The text is stripped and lower-cased, then hashed with SHA-256; the
        first 16 hex characters are used as the key.
        """
        normalized = text.strip().lower()
        return hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:16]

    def _has_hash(self, instruction_hash: str) -> bool:
        """Return True if an entry with this instruction hash already exists."""
        conn = sqlite3.connect(self.db_path)
        try:
            row = conn.execute(
                "SELECT 1 FROM chain_cache WHERE instruction_hash = ?",
                (instruction_hash,)
            ).fetchone()
            return row is not None
        finally:
            conn.close()

    @staticmethod
    def _record_hit(conn: sqlite3.Connection, cache_id: int) -> None:
        """Increment the hit counter of entry ``cache_id`` and commit."""
        conn.execute(
            "UPDATE chain_cache SET hit_count = hit_count + 1, last_hit = ? WHERE id = ?",
            (time.time(), cache_id)
        )
        conn.commit()

    def add(
        self,
        instruction: str,
        output: Dict,
        output_type: str,
        thinking: Optional[str] = None,
    ) -> bool:
        """Add an entry to the cache.

        Args:
            instruction: The user instruction used as the cache key.
            output: JSON-serialisable response payload.
            output_type: Label stored alongside the payload.
            thinking: Optional reasoning text stored with the entry.

        Returns:
            True if a new row was inserted; False if an entry with the same
            normalised instruction already exists or the insert failed.
        """
        instruction_hash = self._hash_instruction(instruction)

        # Check for a duplicate *before* embedding: embedding is by far the
        # most expensive step and would be thrown away by INSERT OR IGNORE.
        try:
            if self._has_hash(instruction_hash):
                return False
        except sqlite3.Error as e:
            print(f"[Cache] Error adding entry: {e}")
            return False

        # Deliberately outside the try below so that a missing embedding
        # dependency is raised to the caller instead of being logged away.
        embedding_blob = self._embed(instruction).tobytes()

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.execute(
                """INSERT OR IGNORE INTO chain_cache
                (instruction_hash, instruction_text, instruction_embedding,
                thinking, output_json, output_type, created_at)
                VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (
                    instruction_hash,
                    instruction,
                    embedding_blob,
                    thinking,
                    json.dumps(output, ensure_ascii=False),
                    output_type,
                    time.time(),
                )
            )
            conn.commit()
            return cursor.rowcount > 0
        except Exception as e:
            # Best-effort: a failed insert is reported, never fatal.
            print(f"[Cache] Error adding entry: {e}")
            return False
        finally:
            conn.close()

    def lookup(self, instruction: str) -> Optional[Dict]:
        """Find a cached response for ``instruction``.

        Returns:
            A dict with keys ``output_type``, ``output``, ``thinking``,
            ``cache_hit`` (``"exact"`` or ``"semantic"``) and ``similarity``
            when a match is found, otherwise None. A semantic match requires
            the best cosine similarity to be >= ``similarity_threshold``.
        """
        instruction_hash = self._hash_instruction(instruction)
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # 1. Exact match on the normalised-instruction hash.
            cursor.execute(
                "SELECT id, output_json, output_type, thinking FROM chain_cache WHERE instruction_hash = ?",
                (instruction_hash,)
            )
            exact = cursor.fetchone()
            if exact:
                cache_id, output_json, output_type, thinking = exact
                self._record_hit(conn, cache_id)
                return {
                    "output_type": output_type,
                    "output": json.loads(output_json),
                    "thinking": thinking,
                    "cache_hit": "exact",
                    "similarity": 1.0
                }

            # Nothing to compare against: skip the embedding work entirely.
            if not cursor.execute("SELECT 1 FROM chain_cache LIMIT 1").fetchall():
                return None

            # 2. Semantic similarity search over id + embedding only; the
            #    payload is fetched afterwards for the single winning row.
            query_embedding = self._embed(instruction)
            expected_nbytes = query_embedding.nbytes
            best_similarity = -1.0
            best_id = None
            cursor.execute("SELECT id, instruction_embedding FROM chain_cache")
            for row_id, emb_blob in cursor:
                # Rows written with a different embedding size (e.g. another
                # model) cannot be compared; skip them instead of raising.
                if len(emb_blob) != expected_nbytes:
                    continue
                cached_embedding = np.frombuffer(emb_blob, dtype=np.float32)
                # Dot product of normalised vectors == cosine similarity.
                similarity = float(np.dot(query_embedding, cached_embedding))
                if similarity > best_similarity:
                    best_similarity = similarity
                    best_id = row_id

            if best_id is None or best_similarity < self.similarity_threshold:
                return None

            cursor.execute(
                "SELECT output_json, output_type, thinking FROM chain_cache WHERE id = ?",
                (best_id,)
            )
            best_row = cursor.fetchone()
            if best_row is None:
                # The row disappeared between the scan and this read.
                return None
            output_json, output_type, thinking = best_row
            self._record_hit(conn, best_id)
            return {
                "output_type": output_type,
                "output": json.loads(output_json),
                "thinking": thinking,
                "cache_hit": "semantic",
                "similarity": round(best_similarity, 4)
            }
        finally:
            conn.close()

    @staticmethod
    def _parse_training_line(line: str) -> Tuple[str, Dict, str, Optional[str]]:
        """Parse one JSONL training record into ``add`` arguments.

        Returns:
            ``(instruction, output, output_type, thinking)``.

        Raises:
            ValueError: If the line is not valid JSON (``JSONDecodeError`` is
                a ``ValueError``) or does not have the expected structure.
        """
        data = json.loads(line.strip())
        if not isinstance(data, dict):
            raise ValueError("record is not a JSON object")
        messages = data.get("messages", [])
        if not isinstance(messages, list):
            raise ValueError("'messages' is not a list")

        # Content of the first message seen for each role.
        first_by_role: Dict[str, object] = {}
        for message in messages:
            if not isinstance(message, dict):
                continue
            role = message.get("role")
            if isinstance(role, str) and role not in first_by_role:
                first_by_role[role] = message.get("content")

        user_msg = first_by_role.get("user")
        assistant_msg = first_by_role.get("assistant")
        if not user_msg or not assistant_msg:
            raise ValueError("missing user or assistant message")
        if not isinstance(user_msg, str) or not isinstance(assistant_msg, str):
            raise ValueError("message content is not text")

        # The assistant message is either plain JSON or
        # "<think>...</think>" followed by JSON.
        thinking = None
        output_json_str = assistant_msg
        if "<think>" in assistant_msg:
            think_start = assistant_msg.index("<think>") + len("<think>")
            think_end = assistant_msg.index("</think>", think_start)
            thinking = assistant_msg[think_start:think_end].strip()
            output_json_str = assistant_msg[think_end + len("</think>"):].strip()

        parsed_output = json.loads(output_json_str)
        if not isinstance(parsed_output, dict):
            raise ValueError("assistant output is not a JSON object")
        return (
            user_msg,
            parsed_output.get("output", {}),
            parsed_output.get("output_type", "unknown"),
            thinking,
        )

    def preseed_from_dataset(self, jsonl_path: str) -> int:
        """Pre-seed the cache from a JSONL training dataset.

        Lines that cannot be parsed into a usable record are skipped.

        Returns:
            The number of entries that were newly added.
        """
        count = 0
        skipped = 0
        # Bound up front so the summary below also works for an empty file.
        line_num = 0
        path = Path(jsonl_path)
        if not path.exists():
            print(f"[Cache] Dataset not found: {jsonl_path}")
            return 0
        print(f"[Cache] Pre-seeding from {jsonl_path}...")
        with open(path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                try:
                    record = self._parse_training_line(line)
                    if self.add(*record):
                        count += 1
                except (json.JSONDecodeError, KeyError, ValueError):
                    # One bad line must not abort the whole seeding run.
                    skipped += 1
                if line_num % 100 == 0:
                    print(f"[Cache] Processed {line_num} lines, added {count}...")
        print(f"[Cache] Pre-seeded {count} entries from {line_num} lines")
        if skipped:
            print(f"[Cache] Skipped {skipped} unusable lines")
        return count

    def stats(self) -> Dict:
        """Return cache statistics.

        Returns:
            A dict with ``total_entries``, ``by_output_type``, ``total_hits``,
            ``entries_with_hits`` and ``hit_rate_estimate`` (the share of
            entries that have been hit at least once, formatted as a
            percentage string).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM chain_cache")
            total = cursor.fetchone()[0]
            cursor.execute("SELECT output_type, COUNT(*) FROM chain_cache GROUP BY output_type")
            by_type = dict(cursor.fetchall())
            cursor.execute("SELECT SUM(hit_count) FROM chain_cache")
            # SUM over zero rows is NULL.
            total_hits = cursor.fetchone()[0] or 0
            cursor.execute(
                "SELECT COUNT(*) FROM chain_cache WHERE hit_count > 0"
            )
            entries_with_hits = cursor.fetchone()[0]
        finally:
            conn.close()
        return {
            "total_entries": total,
            "by_output_type": by_type,
            "total_hits": total_hits,
            "entries_with_hits": entries_with_hits,
            "hit_rate_estimate": f"{entries_with_hits/max(total,1)*100:.1f}%"
        }

    def clear(self):
        """Delete all cache entries."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("DELETE FROM chain_cache")
            conn.commit()
        finally:
            conn.close()
        print("[Cache] Cleared all entries")
# ============================================================
# CLI for testing
# ============================================================
if __name__ == "__main__":
    import sys

    # Small manual-testing CLI: the first argument selects the command.
    cache = ChainCache(
        db_path="chain_cache.sqlite",
        similarity_threshold=0.92
    )
    cli_args = sys.argv[1:]
    command = cli_args[0] if cli_args else None

    if command == "preseed":
        # Optional second argument overrides the default dataset location.
        if len(cli_args) > 1:
            dataset_path = cli_args[1]
        else:
            dataset_path = "dataset/output/training_data_v2.jsonl"
        seeded = cache.preseed_from_dataset(dataset_path)
        print(f"\nPre-seeded {seeded} entries")
        print(json.dumps(cache.stats(), indent=2))
    elif command == "test":
        # Run a few sample instructions through the cache and report hit/miss.
        sample_queries = (
            "Buat rumah minimalis 2 lantai 3 kamar tidur",
            "Desain rumah modern 2 lantai dengan 3 KT",
            "Tambah pintu di dinding ruang tamu",
            "Hitung RAB proyek ini",
        )
        for query in sample_queries:
            print(f"\nQuery: {query}")
            result = cache.lookup(query)
            if not result:
                print(" MISS")
                continue
            print(f" HIT ({result['cache_hit']}, similarity: {result['similarity']})")
            print(f" Type: {result['output_type']}")
    elif command == "stats":
        print(json.dumps(cache.stats(), indent=2))
    else:
        for usage_line in (
            "Usage:",
            " python chain_cache.py preseed [dataset_path]",
            " python chain_cache.py test",
            " python chain_cache.py stats",
        ):
            print(usage_line)