Plaiglab / scripts /train.py
SanidhyaDhangar's picture
PlaigLab — Hugging Face Space (Docker) clean deploy
ebebfe8
Raw
History Blame Contribute Delete
5.88 kB
"""Train all learned components and persist them to models/:
1. Siamese semantic encoder (DL, contrastive loss, pure-numpy backprop)
2. DQN investigation planner (RL, replay buffer + target network)
3. Plagiarism-type classifier (DL, softmax MLP over evidence features)
4. Causal query bandit + Bayesian evidence weights (initial priors)
"""
import json
import os
import sys
import time
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from plagdetect.causal import CausalQueryBandit # noqa: E402
from plagdetect.evidence import EvidenceWeights, compare_section # noqa: E402
from plagdetect.forensics import CLASSES, TypeClassifier # noqa: E402
from plagdetect.ingestion import load_corpus # noqa: E402
from plagdetect.rl_planner import train_dqn # noqa: E402
from plagdetect.siamese import SiameseEncoder # noqa: E402
from plagdetect.textutils import mosaic_mix, sentences, synonymize # noqa: E402
from plagdetect.understanding import build_idf # noqa: E402
from gen_corpus import TOPICS, render_sections # noqa: E402
# Repository root: two directory levels above this script's absolute path.
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Corpus directory: main() loads it and build_classifier_data() scans it for *.json files.
CORPUS_DIR = os.path.join(ROOT, "data", "corpus")
# Output directory for every artifact main() saves; main() creates it if missing.
MODELS_DIR = os.path.join(ROOT, "models")
def word_dropout(text, rng, p=0.15):
    """Return ``text`` with each whitespace-separated word dropped with probability ``p``.

    Exactly one ``rng.rand()`` draw is made per word, in order; a word survives
    when its draw is strictly greater than ``p``. When three or fewer words
    survive, the full word list is used instead. In both cases the words are
    re-joined with single spaces, so the original whitespace is normalised.
    """
    tokens = text.split()
    survivors = []
    for token in tokens:
        if rng.rand() > p:
            survivors.append(token)
    # Fall back to the untouched word list when 3 or fewer words survived.
    chosen = survivors if len(survivors) > 3 else tokens
    return " ".join(chosen)
def build_siamese_pairs(corpus, rng, n_pos=1500, n_neg=1500):
    """Build labelled sentence pairs for training the Siamese encoder.

    Sentences are pooled from the introduction, literature_review,
    methodology and results sections of every document; a missing section
    contributes nothing.

    Args:
        corpus: iterable of documents exposing a ``sections`` mapping.
        rng: random source providing ``rand()`` and ``randint(high)``
            (main() passes a ``numpy.random.RandomState``).
        n_pos: number of positive pairs to generate.
        n_neg: number of negative pairs to *attempt*.

    Returns:
        A list of ``(sentence_a, sentence_b, label)`` tuples. Label 1 pairs a
        sentence with a transformed copy of itself (synonymized or
        word-dropped, chosen with equal probability); label 0 pairs two
        independently drawn sentences. The list can hold fewer than ``n_neg``
        negatives, because a draw whose two sentences share their first 40
        characters is discarded rather than redrawn.

    Raises:
        ValueError: if pairs are requested but the corpus yields no sentences.
    """
    section_names = ("introduction", "literature_review", "methodology", "results")
    all_sents = []
    for d in corpus:
        for sec in section_names:
            all_sents.extend(sentences(d.sections.get(sec, "")))

    n_sents = len(all_sents)
    if n_sents == 0:
        if n_pos > 0 or n_neg > 0:
            # Without this guard rng.randint(0) fails with an opaque bounds error.
            raise ValueError(
                "build_siamese_pairs: corpus yielded no sentences to sample pairs from"
            )
        return []

    pairs = []
    for _ in range(n_pos):
        s = all_sents[rng.randint(n_sents)]
        # Draw order (index, coin flip, transform) is kept fixed for reproducibility.
        t = synonymize(s, rng, p=0.7) if rng.rand() < 0.5 else word_dropout(s, rng)
        pairs.append((s, t, 1))
    for _ in range(n_neg):
        a = all_sents[rng.randint(n_sents)]
        b = all_sents[rng.randint(n_sents)]
        # A shared 40-character prefix is treated as "same sentence"; skip such draws.
        if a[:40] != b[:40]:
            pairs.append((a, b, 0))
    return pairs
def build_classifier_data(corpus, encoder, idf, default_idf, rng, n_per=110):
    """Self-consistent training: transforms -> real evidence features -> label.

    For each of ``n_per`` rounds a random source document and section are
    drawn, five variants of that section are produced (clean, clone,
    find_replace, mosaic, idea), and every variant is compared against the
    source with ``compare_section``. The resulting ``"features"`` entry
    becomes one sample, labelled with the variant's index in ``CLASSES``.

    Args:
        corpus: documents exposing ``doc_id``, ``sections`` and ``references``.
        encoder, idf, default_idf: forwarded unchanged to ``compare_section``.
        rng: random source providing ``randint`` (main() passes a
            ``numpy.random.RandomState``); also handed to the text transforms.
        n_per: number of rounds; each round yields one sample per variant.

    Returns:
        ``(X, y)`` as NumPy arrays built from the feature and label lists.

    Raises:
        ValueError: if samples are requested but the corpus spans fewer than
            two topics (the clean and mosaic variants need a document from a
            different topic than the source).
    """
    # Raw corpus JSON, keyed by file stem, supplies each document's "topic" and "vals".
    raws = {}
    for fn in os.listdir(CORPUS_DIR):
        if fn.endswith(".json"):
            with open(os.path.join(CORPUS_DIR, fn), "r", encoding="utf-8") as f:
                raws[fn[:-5]] = json.load(f)

    # Materialise once so a one-shot iterable serves both grouping and sampling.
    docs = list(corpus)
    by_topic = {}
    for d in docs:
        by_topic.setdefault(raws[d.doc_id]["topic"], []).append(d)

    if n_per > 0 and len(by_topic) < 2:
        # Otherwise the first round fails inside rng.randint(0) with an opaque error.
        raise ValueError(
            "build_classifier_data needs documents from at least two topics; "
            f"found {len(by_topic)}"
        )

    X, y = [], []
    secs = ["literature_review", "methodology", "results"]
    for it in range(n_per):
        # The order of rng draws below is kept fixed for reproducibility.
        src = docs[rng.randint(len(docs))]
        topic = raws[src.doc_id]["topic"]
        sec = secs[rng.randint(len(secs))]
        src_text = src.sections[sec]
        src_sents = sentences(src_text)
        other_topics = [k for k in by_topic if k != topic]
        other_topic = other_topics[rng.randint(len(other_topics))]
        other = by_topic[other_topic][rng.randint(len(by_topic[other_topic]))]
        other_text = other.sections[sec]
        variants = {
            "clean": (other_text, other.references),
            "clone": (src_text, list(src.references)),
            "find_replace": (synonymize(src_text, rng, p=0.9),
                             list(src.references[:6])),
            "mosaic": (mosaic_mix(src_sents, sentences(other_text), rng),
                       list(src.references[:5]) + list(other.references[:3])),
            "idea": (render_sections(TOPICS[topic], raws[src.doc_id]["vals"],
                                     np.random.RandomState(rng.randint(10_000)),
                                     alt=True)[sec],
                     list(src.references[:6])),
        }
        for label, (text, refs) in variants.items():
            ev = compare_section(text, refs, src, encoder, idf, default_idf)
            X.append(ev["features"])
            y.append(CLASSES.index(label))
        if (it + 1) % 25 == 0:
            # Sample counts follow the actual number of variants per round.
            per_round = len(variants)
            print(f" classifier data {per_round * (it + 1)}/{per_round * n_per} samples")
    return np.array(X), np.array(y)
def main():
    """Train every learned component in sequence and save each under MODELS_DIR.

    Stages, in order:
      1. Siamese encoder            -> siamese.npz
      2. DQN investigation planner  -> dqn.npz
      3. Plagiarism-type classifier -> classifier.npz, with accuracy reported
         on a shuffled 15% hold-out
      4. Default-constructed causal bandit and evidence weights
                                    -> bandit.json, weights.json

    A single ``RandomState(5)`` feeds stage 1, stage 3 and the train/test
    shuffle, so the stage order determines the random stream each one sees.
    """
    os.makedirs(MODELS_DIR, exist_ok=True)

    def artifact(filename):
        # Path of an output file inside MODELS_DIR.
        return os.path.join(MODELS_DIR, filename)

    rng = np.random.RandomState(5)
    corpus = load_corpus(CORPUS_DIR)
    idf, default_idf = build_idf(corpus)

    # Stage 1: Siamese encoder.
    print(f"[1/4] Siamese encoder -- corpus={len(corpus)} docs")
    started = time.time()
    encoder = SiameseEncoder()
    pairs = build_siamese_pairs(corpus, rng)
    encoder.train(pairs, epochs=5)
    encoder.save(artifact("siamese.npz"))
    print(f" saved siamese.npz ({time.time() - started:.1f}s)")

    # Stage 2: DQN planner.
    print("[2/4] DQN investigation planner")
    started = time.time()
    qnet, returns = train_dqn(episodes=900)
    qnet.save(artifact("dqn.npz"))
    recent_returns = returns[-100:]
    print(f" saved dqn.npz, final avg return "
          f"{np.mean(recent_returns):+.2f} ({time.time() - started:.1f}s)")

    # Stage 3: type classifier, trained on 85% of a shuffled sample set.
    print("[3/4] Plagiarism-type classifier (features via real evidence pipeline)")
    started = time.time()
    X, y = build_classifier_data(corpus, encoder, idf, default_idf, rng)
    split = int(0.85 * len(X))
    order = rng.permutation(len(X))
    train_idx, test_idx = order[:split], order[split:]
    Xtr, ytr = X[train_idx], y[train_idx]
    Xte, yte = X[test_idx], y[test_idx]
    clf = TypeClassifier()
    clf.train(Xtr, ytr, epochs=60)
    predicted = [clf.predict(row)[0] for row in Xte]
    hits = [CLASSES.index(label) == truth for label, truth in zip(predicted, yte)]
    acc = float(np.mean(hits))
    clf.save(artifact("classifier.npz"))
    print(f" saved classifier.npz, held-out acc={acc:.3f} "
          f"({time.time() - started:.1f}s)")

    # Stage 4: persist the default-constructed bandit and evidence weights.
    print("[4/4] Causal bandit priors + Bayesian evidence weights")
    CausalQueryBandit().save(artifact("bandit.json"))
    EvidenceWeights().save(artifact("weights.json"))
    print(" saved bandit.json, weights.json")
    print("Training complete.")
if __name__ == "__main__":
main()