# Source: TWL / db.py (Hugging Face repository; author: rubentsui)
# Commit c4f9e86 ("Upload db.py with huggingface_hub"), verified.
"""Database query layer for TWL Concordancer."""
import sqlite3
from pathlib import Path
from typing import Iterable
import regex
# Default database file: "twl_concordancer.db" in the same directory as this module.
DEFAULT_DB = Path(__file__).with_name("twl_concordancer.db")
def get_conn(db_path=None):
    """Open a SQLite connection to the concordancer database.

    Args:
        db_path: Database location (anything ``str()`` turns into a path, or
            ``":memory:"``). ``None`` selects ``DEFAULT_DB``.

    Returns:
        A ``sqlite3.Connection`` whose rows are ``sqlite3.Row`` objects, with
        ``busy_timeout`` set to 5000 ms and WAL journaling requested.
    """
    target = DEFAULT_DB if db_path is None else db_path
    connection = sqlite3.connect(str(target))
    connection.row_factory = sqlite3.Row
    # Wait on a locked database (up to 5000 ms) rather than failing at once.
    connection.execute("PRAGMA busy_timeout=5000")
    try:
        connection.execute("PRAGMA journal_mode=WAL")
    except sqlite3.OperationalError:
        # Best effort: some hosted environments permit plain read access but
        # refuse to switch the journal mode, so keep the default mode there.
        pass
    return connection
def _expand_category_options(categories: Iterable[str]):
    """Return category filter options, adding every ancestor prefix.

    Each non-blank category is kept (whitespace-stripped). When it has more
    than one non-empty ``>``-separated segment, each proper ancestor is also
    added as a prefix option ending in ``>``; e.g. ``"A>B>C"`` contributes
    ``"A>B>C"``, ``"A>"`` and ``"A>B>"``. ``None``/blank entries are skipped.
    The result is de-duplicated and ordered by ``_category_sort_key``.

    NOTE(review): prefixes are rebuilt from whitespace-stripped segments joined
    with a bare ``>``, so a stored category such as ``"A > B"`` yields the
    option ``"A>"``, which does not textually prefix the stored value --
    confirm stored categories never have spaces around ``>``.
    """
    options = set()
    for raw in categories:
        name = (raw or "").strip()
        if name:
            options.add(name)
            segments = [seg.strip() for seg in name.split(">")]
            segments = [seg for seg in segments if seg]
            # range() is empty for 0/1 segments, so single-level names add nothing.
            options.update(
                ">".join(segments[:depth]) + ">" for depth in range(1, len(segments))
            )
    return sorted(options, key=_category_sort_key)
def _category_sort_key(category: str):
category = (category or "").strip()
is_repealed = category.startswith("廢止法規>")
return (1 if is_repealed else 0, category)
def _build_order_by(lang):
if lang == "zh":
return """
ORDER BY
CASE
WHEN s.en_text IS NOT NULL AND trim(s.en_text) != '' THEN 0
ELSE 1
END,
s.alignment_score,
s.id
"""
if lang == "en":
return """
ORDER BY
CASE
WHEN s.zh_text IS NOT NULL AND trim(s.zh_text) != '' THEN 0
ELSE 1
END,
s.alignment_score,
s.id
"""
return """
ORDER BY
CASE
WHEN s.zh_text IS NOT NULL AND trim(s.zh_text) != ''
AND s.en_text IS NOT NULL AND trim(s.en_text) != '' THEN 0
ELSE 1
END,
s.alignment_score,
s.id
"""
def search_sentences(
    conn,
    query,
    use_regex=False,
    case_sensitive=False,
    law_id=None,
    category=None,
    article_no=None,
    max_score=None,
    lang="both",
    limit=100,
    offset=0,
):
    """Search aligned sentences and return ``(rows, total)``.

    Args:
        conn: Open database connection (see ``get_conn``).
        query: Search text, or a regular expression when ``use_regex`` is true.
        use_regex: Route to ``_search_regex`` instead of ``_search_like``.
        case_sensitive: Forwarded to the chosen backend.
        law_id, category, article_no, max_score: Optional filters forwarded
            unchanged.
        lang: ``"zh"`` or ``"en"`` to search one side; any other value searches
            both.
        limit, offset: Page size and starting row.

    Returns:
        Whatever the selected backend returns: the page of row dicts and the
        total match count.
    """
    backend = _search_regex if use_regex else _search_like
    return backend(
        conn,
        query,
        case_sensitive,
        law_id,
        category,
        article_no,
        max_score,
        lang,
        limit,
        offset,
    )
def _search_like(
conn, query, case_sensitive, law_id, category, article_no, max_score, lang, limit, offset
):
terms = query.strip()
if not terms:
return [], 0
where = []
params = []
if lang == "zh":
where.append("s.zh_text LIKE ?")
params.append(f"%{terms}%")
elif lang == "en":
if case_sensitive:
where.append("instr(s.en_text, ?) > 0")
params.append(terms)
else:
where.append("instr(lower(s.en_text), lower(?)) > 0")
params.append(terms)
else:
if case_sensitive:
where.append("(s.zh_text LIKE ? OR instr(s.en_text, ?) > 0)")
params.extend([f"%{terms}%", terms])
else:
where.append("(s.zh_text LIKE ? OR instr(lower(s.en_text), lower(?)) > 0)")
params.extend([f"%{terms}%", terms])
if max_score is not None:
where.append("s.alignment_score <= ?")
params.append(max_score)
if law_id:
where.append("l.law_id = ?")
params.append(law_id)
if category:
if category.endswith(">"):
where.append("l.category LIKE ?")
params.append(f"{category}%")
else:
where.append("l.category = ?")
params.append(category)
if article_no:
pat = f"%{article_no}%"
where.append("(a.article_no_zh LIKE ? OR a.article_no_en LIKE ?)")
params.extend([pat, pat])
where_clause = " AND ".join(where)
count_sql = f"""
SELECT count(*) FROM sentences s
JOIN laws l ON s.law_id = l.id
JOIN articles a ON s.article_id = a.id
WHERE {where_clause}
"""
order_by = _build_order_by(lang)
data_sql = f"""
SELECT s.id, s.zh_text, s.en_text, s.alignment_score,
l.law_id, l.zh_name, l.en_name, l.type,
a.article_no_zh, a.article_no_en, a.article_type,
s.zh_sentence_idx, s.en_sentence_idx
FROM sentences s
JOIN laws l ON s.law_id = l.id
JOIN articles a ON s.article_id = a.id
WHERE {where_clause}
{order_by}
LIMIT ? OFFSET ?
"""
data_params = params + [limit, offset]
cur = conn.execute(count_sql, params)
total = cur.fetchone()[0]
cur = conn.execute(data_sql, data_params)
rows = [dict(r) for r in cur.fetchall()]
return rows, total
def _search_regex(
conn,
pattern,
case_sensitive,
law_id,
category,
article_no,
max_score,
lang,
limit,
offset,
):
try:
regex.compile(pattern, flags=regex.V1 if case_sensitive else regex.V1 | regex.IGNORECASE)
except regex.error:
return [], 0
where = "1=1"
params = []
if lang == "zh":
where += " AND s.zh_text REGEXP ?"
params.append(pattern)
elif lang == "en":
where += " AND s.en_text REGEXP ?"
params.append(pattern)
else:
where += " AND (s.zh_text REGEXP ? OR s.en_text REGEXP ?)"
params.extend([pattern, pattern])
if max_score is not None:
where += " AND s.alignment_score <= ?"
params.append(max_score)
if law_id:
where += " AND l.law_id = ?"
params.append(law_id)
if category:
if category.endswith(">"):
where += " AND l.category LIKE ?"
params.append(f"{category}%")
else:
where += " AND l.category = ?"
params.append(category)
if article_no:
where += " AND (a.article_no_zh LIKE ? OR a.article_no_en LIKE ?)"
pat = f"%{article_no}%"
params.extend([pat, pat])
count_sql = f"""
SELECT count(*) FROM sentences s
JOIN laws l ON s.law_id = l.id
JOIN articles a ON s.article_id = a.id
WHERE {where}
"""
order_by = _build_order_by(lang)
data_sql = f"""
SELECT s.id, s.zh_text, s.en_text, s.alignment_score,
l.law_id, l.zh_name, l.en_name, l.type,
a.article_no_zh, a.article_no_en, a.article_type,
s.zh_sentence_idx, s.en_sentence_idx
FROM sentences s
JOIN laws l ON s.law_id = l.id
JOIN articles a ON s.article_id = a.id
WHERE {where}
{order_by}
LIMIT ? OFFSET ?
"""
data_params = params + [limit, offset]
conn.create_function(
"REGEXP",
2,
lambda pat, txt: bool(
regex.search(
pat,
txt,
flags=regex.V1 if case_sensitive else regex.V1 | regex.IGNORECASE,
)
)
if txt
else False,
)
cur = conn.execute(count_sql, params)
total = cur.fetchone()[0]
cur = conn.execute(data_sql, data_params)
rows = [dict(r) for r in cur.fetchall()]
return rows, total
def get_paragraph(conn, sentence_id):
    """Return every sentence of the paragraph that contains ``sentence_id``.

    Args:
        conn: Connection whose rows convert with ``dict()`` (``get_conn`` sets
            ``sqlite3.Row``).
        sentence_id: ``sentences.id`` of any sentence in the paragraph.

    Returns:
        ``None`` when nothing is found (e.g. unknown ``sentence_id``); otherwise
        a dict with the paragraph's ``paragraph_index``, the article's
        ``article_no_zh`` / ``article_no_en``, and ``sentences``: one row dict
        per sentence, ordered by ``zh_sentence_idx`` and then ``id``.
    """
    cur = conn.execute(
        """
        SELECT s.id, s.zh_text, s.en_text, s.alignment_score, s.zh_sentence_idx, s.en_sentence_idx,
               p.paragraph_index, a.article_no_zh, a.article_no_en
        FROM sentences s
        JOIN paragraphs p ON s.paragraph_id = p.id
        JOIN articles a ON s.article_id = a.id
        WHERE s.paragraph_id = (SELECT paragraph_id FROM sentences WHERE id = ?)
        -- s.id breaks ties (equal or NULL zh_sentence_idx) so the order is deterministic.
        ORDER BY s.zh_sentence_idx, s.id
        """,
        (sentence_id,),
    )
    rows = [dict(r) for r in cur.fetchall()]
    if not rows:
        return None
    first = rows[0]
    return {
        "paragraph_index": first["paragraph_index"],
        "article_no_zh": first["article_no_zh"],
        "article_no_en": first["article_no_en"],
        "sentences": rows,
    }
def get_article(conn, sentence_id):
    """Return the whole article containing ``sentence_id``, grouped by paragraph.

    Args:
        conn: Connection whose rows convert with ``dict()`` (``get_conn`` sets
            ``sqlite3.Row``).
        sentence_id: ``sentences.id`` of any sentence in the article.

    Returns:
        ``None`` when nothing is found; otherwise a dict with the article's
        ``article_no_zh``, ``article_no_en`` and ``article_type`` plus
        ``paragraphs``: one dict per distinct ``paragraph_index`` (carrying
        ``paragraph_index``, ``paragraph_id`` and ``sentences``) in the order
        SQLite sorts ``paragraph_index``. Each sentence dict has ``id``,
        ``zh_text``, ``en_text``, ``alignment_score``, ``zh_sentence_idx`` and
        ``en_sentence_idx``.
    """
    cur = conn.execute(
        """
        SELECT s.id, s.zh_text, s.en_text, s.alignment_score, s.zh_sentence_idx, s.en_sentence_idx,
               p.paragraph_index, p.id AS paragraph_id,
               a.article_no_zh, a.article_no_en, a.article_type
        FROM sentences s
        JOIN paragraphs p ON s.paragraph_id = p.id
        JOIN articles a ON s.article_id = a.id
        WHERE s.article_id = (SELECT article_id FROM sentences WHERE id = ?)
        -- s.id breaks ties (equal or NULL zh_sentence_idx) so the order is deterministic.
        ORDER BY p.paragraph_index, s.zh_sentence_idx, s.id
        """,
        (sentence_id,),
    )
    rows = [dict(r) for r in cur.fetchall()]
    if not rows:
        return None

    sentence_keys = (
        "id",
        "zh_text",
        "en_text",
        "alignment_score",
        "zh_sentence_idx",
        "en_sentence_idx",
    )
    paragraphs = {}
    for r in rows:
        pidx = r["paragraph_index"]
        if pidx not in paragraphs:
            paragraphs[pidx] = {
                "paragraph_index": pidx,
                "paragraph_id": r["paragraph_id"],
                "sentences": [],
            }
        paragraphs[pidx]["sentences"].append({k: r[k] for k in sentence_keys})

    first = rows[0]
    return {
        "article_no_zh": first["article_no_zh"],
        "article_no_en": first["article_no_en"],
        "article_type": first["article_type"],
        # Rows arrive ordered by paragraph_index, so insertion order is already
        # paragraph order. Re-sorting the keys in Python would raise TypeError
        # if one paragraph_index were NULL and the others numeric.
        "paragraphs": list(paragraphs.values()),
    }
def list_laws(conn, law_type=None, category=None):
    """List laws (``law_id``, ``type``, ``zh_name``, ``en_name``, ``category``).

    Args:
        conn: Connection whose rows convert with ``dict()``.
        law_type: When truthy, keep only laws whose ``type`` equals it.
        category: When truthy, an exact ``category`` match; a value ending in
            ``>`` instead matches every category starting with it (SQL LIKE).

    Returns:
        A list of row dicts ordered by ``law_id``.
    """
    filters = []
    if law_type:
        filters.append(("type = ?", law_type))
    if category:
        filters.append(
            ("category LIKE ?", f"{category}%")
            if category.endswith(">")
            else ("category = ?", category)
        )
    where_clause = " AND ".join(clause for clause, _ in filters) or "1=1"
    params = [value for _, value in filters]
    found = conn.execute(
        f"SELECT law_id, type, zh_name, en_name, category FROM laws WHERE {where_clause} ORDER BY law_id",
        params,
    ).fetchall()
    return [dict(row) for row in found]
def list_categories(conn, law_type=None):
    """Return the category filter options derived from the ``laws`` table.

    Distinct non-NULL, non-empty ``laws.category`` values (limited to one
    ``type`` when ``law_type`` is truthy) are passed through
    ``_expand_category_options``, which adds ancestor prefixes and sorts.
    """
    not_blank = "category IS NOT NULL AND category != ''"
    if law_type:
        where, params = f"WHERE type = ? AND {not_blank}", [law_type]
    else:
        where, params = f"WHERE {not_blank}", []
    found = conn.execute(
        f"SELECT DISTINCT category FROM laws {where} ORDER BY category",
        params,
    ).fetchall()
    return _expand_category_options(row["category"] for row in found)
def get_law_articles(conn, law_id):
    """List the articles of one law in ``article_index`` order.

    Args:
        conn: Connection whose rows convert with ``dict()``.
        law_id: The external ``laws.law_id``; it is mapped to the internal
            ``laws.id`` that ``articles.law_id`` stores.

    Returns:
        A list of dicts with ``article_no_zh``, ``article_no_en``,
        ``article_type`` and ``article_index``; empty for an unknown law.
    """
    sql = """
        SELECT article_no_zh, article_no_en, article_type, article_index
        FROM articles
        WHERE law_id = (SELECT id FROM laws WHERE law_id = ?)
        ORDER BY article_index
    """
    return [dict(row) for row in conn.execute(sql, (law_id,))]
def get_law_full_text(conn, law_id):
    """Return every aligned sentence of one law in reading order.

    Args:
        conn: Connection whose rows convert with ``dict()``.
        law_id: The external ``laws.law_id`` (not the internal ``laws.id``
            that ``sentences.law_id`` references).

    Returns:
        A list of row dicts (``zh_text``, ``en_text``, ``alignment_score``,
        ``article_no_zh``, ``article_no_en``, ``article_type``,
        ``article_index``, ``paragraph_index``) ordered by article index,
        paragraph index, ``zh_sentence_idx`` and finally sentence id. The list
        is empty when the law is unknown or has no sentences.
    """
    cur = conn.execute(
        """
        SELECT s.zh_text, s.en_text, s.alignment_score,
               a.article_no_zh, a.article_no_en, a.article_type, a.article_index,
               p.paragraph_index
        FROM sentences s
        JOIN paragraphs p ON s.paragraph_id = p.id
        JOIN articles a ON s.article_id = a.id
        JOIN laws l ON s.law_id = l.id
        WHERE l.law_id = ?
        -- s.id breaks ties (equal or NULL zh_sentence_idx) so the order is deterministic.
        ORDER BY a.article_index, p.paragraph_index, s.zh_sentence_idx, s.id
        """,
        (law_id,),
    )
    return [dict(r) for r in cur.fetchall()]