XGLM_odd_ones / app.py
batheand's picture
Create app.py
1de6f95 verified
import html
import re

import gradio as gr
import numpy as np
import torch
from scipy.ndimage import uniform_filter1d
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face model id of the causal language model used for scoring.
MODEL = "facebook/xglm-564M" # upgrade to xglm-2.9B if you get a GPU Space
# Use the GPU when torch reports CUDA as available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Tokenizer and model are loaded once, at import time, as module-level globals.
tok = AutoTokenizer.from_pretrained(MODEL)
# The model is moved to DEVICE and put in evaluation mode before any scoring.
lm = AutoModelForCausalLM.from_pretrained(MODEL).to(DEVICE).eval()
# One token is either a run of word characters or a single character that is
# neither a word character nor whitespace (i.e. a punctuation mark or symbol).
_TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", flags=re.UNICODE)


def split_words(text: str):
    """Split *text* into word and punctuation tokens, discarding whitespace.

    Args:
        text: The string to split.

    Returns:
        A list of tokens in order of appearance. Each run of word characters
        is one token; every other non-whitespace character is its own token.
    """
    return [match.group(0) for match in _TOKEN_PATTERN.finditer(text)]
def _word_char_spans(text, words):
    """Return the ``(start, end)`` character span of each word in *text*.

    *words* must be the tokens of *text* in order of appearance; each one is
    searched for starting where the previous one ended, so repeated words map
    to successive occurrences.
    """
    spans = []
    cursor = 0
    for word in words:
        start = text.find(word, cursor)
        cursor = start + len(word)
        spans.append((start, cursor))
    return spans


def _sum_token_scores_by_word(token_scores, token_offsets, spans, n_chars):
    """Add each token's score to the word that owns most of its characters.

    A token that lies entirely inside one word is credited to that word. A
    token whose offsets straddle a word boundary (for example because they
    include the preceding space or an adjoining punctuation mark) is credited
    to the word it overlaps most, lowest word index winning ties, instead of
    being dropped. Tokens that overlap no word at all contribute nothing.
    """
    # owner[c] is the index of the word covering character c, or -1 for none.
    owner = np.full(n_chars, -1, dtype=np.int64)
    for idx, (start, end) in enumerate(spans):
        owner[start:end] = idx

    totals = np.zeros(len(spans), dtype=float)
    for score, (a, b) in zip(token_scores, token_offsets):
        covered = owner[a:b]
        covered = covered[covered >= 0]
        if covered.size == 0:
            # Empty offsets (a == b) or a token covering only whitespace.
            continue
        totals[np.bincount(covered).argmax()] += float(score)
    return totals


@torch.inference_mode()
def word_surprisal(text: str):
    """Score every word of *text* by its summed next-token surprisal.

    The text is tokenized with the module-level ``tok`` and run through the
    module-level ``lm``. For each token after the first, the negative natural
    log-probability the model assigned to it given the preceding tokens is
    computed, and these values are summed per word as produced by
    ``split_words``.

    Args:
        text: The sentence to score.

    Returns:
        A ``(words, scores)`` pair: the list of word/punctuation tokens and a
        float NumPy array of the same length holding each word's summed
        surprisal (0.0 for a word no scored token maps to).
    """
    enc = tok(text, return_tensors="pt", return_offsets_mapping=True)
    ids = enc["input_ids"].to(DEVICE)
    offsets = enc["offset_mapping"][0].tolist()

    # Position t predicts token t + 1, so drop the last logit and first target.
    logits = lm(ids).logits[:, :-1, :]
    targets = ids[:, 1:]
    log_probs = torch.log_softmax(logits, dim=-1)
    token_ll = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # [1, T-1]
    nll = (-token_ll).squeeze(0).cpu().numpy()

    # The first token is never scored, so its offset is skipped as well.
    # NOTE(review): this presumes the tokenizer prepends exactly one special
    # token and appends none -- confirm if MODEL is changed.
    scored_offsets = offsets[1:]

    words = split_words(text)
    spans = _word_char_spans(text, words)
    w_scores = _sum_token_scores_by_word(nll, scored_offsets, spans, len(text))
    return words, w_scores
def robust_threshold(scores, k=2.5):
    """Return the outlier cut-off ``median + k * MAD`` for *scores*.

    MAD is the median absolute deviation from the median. A tiny constant
    (1e-8) is added to it so the cut-off is always strictly above the median,
    even when more than half of the scores are identical.

    Args:
        scores: Any array-like of numbers (NumPy array, list, tuple, ...).
        k: How many MADs above the median the cut-off lies.

    Returns:
        The threshold as a float, or ``1e9`` when *scores* is empty so that
        nothing can reach it.
    """
    # Coerce once so plain Python sequences support the arithmetic below.
    values = np.asarray(scores, dtype=float)
    if values.size == 0:
        return 1e9
    med = float(np.median(values))
    mad = float(np.median(np.abs(values - med))) + 1e-8
    return med + k * mad
def infer(text, k=2.5, smooth=2):
    """Highlight the unusually surprising words of *text*.

    Args:
        text: The sentence to analyse.
        k: Multiplier passed to ``robust_threshold``; a word is flagged when
            its score is at or above ``median + k * MAD``.
        smooth: Width, in words, of the moving-average window applied to the
            scores before thresholding. Values of 1 or less disable smoothing.

    Returns:
        A ``(html, table)`` pair. ``html`` is the text with flagged words
        wrapped in ``<mark>`` tags; ``table`` is a list of ``(word, score)``
        tuples holding the (smoothed) score of every word. Empty or
        whitespace-only input yields ``("", [])`` without running the model.
    """
    if not text or not text.strip():
        return "", []

    words, scores = word_surprisal(text)

    # The UI slider may deliver a float; uniform_filter1d needs an int size.
    window = int(smooth) if smooth else 0
    if window > 1:
        scores = uniform_filter1d(scores, size=window, mode="nearest")

    thr = robust_threshold(scores, k=k)
    flagged = {i for i, s in enumerate(scores) if s >= thr}

    pieces = []
    for i, w in enumerate(words):
        # The words come from user input and end up in HTML, so escape them.
        safe = html.escape(w, quote=False)
        if i in flagged:
            pieces.append(f"<mark style='background:#ffb3b3'>{safe}</mark>")
        else:
            pieces.append(safe)

    # Re-attach full stops and commas to the preceding word. The replacements
    # previously swapped each pattern for itself and so had no effect.
    md = " ".join(pieces).replace(" .", ".").replace(" ,", ",")
    table = [(w, float(s)) for w, s in zip(words, scores)]
    return md, table
# Input widgets, in the positional order of infer(text, k, smooth).
# The textbox label is Turkish for "Enter a sentence".
_demo_inputs = [
    gr.Textbox(lines=3, label="Cümle girin"),
    gr.Slider(1.5, 4.0, value=2.5, step=0.1, label="Threshold (median + k*MAD)"),
    gr.Slider(1, 7, value=2, step=1, label="Smoothing window (words)"),
]

# Output widgets, matching the (html, table) pair returned by infer().
_demo_outputs = [
    gr.HTML(label="Highlighted"),
    gr.Dataframe(
        headers=["Word", "Score"],
        label="Word surprisal (lower=more normal)",
    ),
]

# The Gradio app object; built at import time, started only when run as a script.
demo = gr.Interface(
    fn=infer,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="Odd-Word Detector — XGLM (autoregressive)",
    description="Autoregressive next-token likelihood per word; flags unusually unlikely words.",
)

if __name__ == "__main__":
    demo.launch()