# ocs / app.py — Hugging Face Space "usmannawaz/ocs"
# (uploaded by usmannawaz, "Update app.py", commit 1c80f02, verified)
# NOTE: the lines above were scraped web-page chrome from the Space file
# listing; kept here as a comment so the module parses.
import os
import re
import json
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
# --- Model & classifier-head setup (runs once at import time) ---
MODEL_ID = "usmannawaz/ocs"
MAX_LENGTH = 512        # tokens per segment fed to the encoder
STRIDE = 0              # no token overlap between overflow segments
MAX_SEGS_EVAL = 128     # hard cap on segments per document at inference
SEG_FORWARD_BS = 8      # segments per encoder forward pass (memory bound)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence fast-tokenizer fork warning

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
encoder = AutoModel.from_pretrained(MODEL_ID).to(DEVICE)
encoder.eval()
hidden = encoder.config.hidden_size

# Label mapping shipped alongside the model weights on the Hub.
label_path = hf_hub_download(MODEL_ID, "label2id.json")
with open(label_path, "r", encoding="utf-8") as f:
    label2id = json.load(f)
id2label = {v: k for k, v in label2id.items()}

# Linear document-classification head trained separately from the encoder.
# weights_only=True restricts torch.load to tensor/container data instead of
# running arbitrary pickled code — the safe way to load a downloaded file.
head_path = hf_hub_download(MODEL_ID, "doc_classifier_head.pt")
head_state = torch.load(head_path, map_location="cpu", weights_only=True)
classifier = torch.nn.Linear(hidden, len(label2id))
classifier.load_state_dict(head_state)
classifier = classifier.to(DEVICE)
classifier.eval()
# Display metadata per model label: century range and the human-readable
# language name shown to the user by predict().
# NOTE(review): keys are assumed to match the labels in label2id.json
# shipped with the model — confirm against that file.
LABEL_INFO = {
    "Old Church Slavonic": {"century": "9th–11th", "language": "Old Church Slavonic"},
    "Church Slavonic": {"century": "12th–17th", "language": "Church Slavonic"},
    "New Church Slavonic": {"century": "18th", "language": "New Church Slavonic"},
    "Ruthenian": {"century": "15th–18th", "language": "Ruthenian"},
}

# Biblical-reference pattern: one or two short words (each optionally
# "."-terminated), then a chapter number, a separator (":", "." or ","),
# a verse number, and an optional "-"/"–" verse range, e.g. "Gen. 1:1".
BIBLE_REF_RE = re.compile(
    r"""\b(?:[^\W\d_]{1,20}\.?(?:\s+[^\W\d_]{1,20}\.?)?)\s*\d{1,3}\s*[:.,]\s*\d{1,3}(?:\s*[-–]\s*\d{1,3})?\b""",
    re.UNICODE,
)

# Typographic/editorial marks (daggers, pilcrows, bullets, crosses, …)
# stripped before classification — they carry no linguistic signal.
NOISY_SYMBOLS_RE = re.compile(r"[†‡§¶¦|‖•·◦※◆◇■□▲△▼▽★☆✝︎✟✠]+", re.UNICODE)
def preprocess_text(t: str) -> str:
    """Normalize a raw document for classification.

    Removes biblical references, folio-style tokens (digits + short letter
    suffix), bare small numbers, bracketed lacuna markers, short bracketed
    insertions (brackets dropped, content kept) and editorial symbols,
    then collapses whitespace and lowercases. Returns "" for empty input.
    """
    if not t:
        return ""
    # Same patterns as the module-level BIBLE_REF_RE / NOISY_SYMBOLS_RE;
    # re.compile caches compiled patterns, so rebuilding here is cheap.
    bible_ref = re.compile(
        r"""\b(?:[^\W\d_]{1,20}\.?(?:\s+[^\W\d_]{1,20}\.?)?)\s*\d{1,3}\s*[:.,]\s*\d{1,3}(?:\s*[-–]\s*\d{1,3})?\b""",
        re.UNICODE,
    )
    noisy = re.compile(r"[†‡§¶¦|‖•·◦※◆◇■□▲△▼▽★☆✝︎✟✠]+", re.UNICODE)
    cleanup_steps = (
        lambda s: bible_ref.sub(" ", s),                                # scripture refs
        lambda s: re.sub(r"\b\d+[A-Za-zА-Яа-я]{1,3}\b", " ", s),        # folio-style "12v"
        lambda s: re.sub(r"\b\d{1,3}\b", " ", s),                       # bare small numbers
        lambda s: re.sub(r"\[\s*\.{1,}\s*\]", " ", s),                  # lacunae "[...]"
        lambda s: re.sub(r"\[([^\W\d_]{1,3})\]", r"\1", s, flags=re.UNICODE),  # "[а]" -> "а"
        lambda s: noisy.sub(" ", s),                                    # editorial marks
    )
    for step in cleanup_steps:
        t = step(t)
    return re.sub(r"\s+", " ", t).strip().lower()
@torch.no_grad()
def predict(text: str) -> str:
    """Classify a document and report its language and century.

    The text is cleaned, tokenized into fixed-length segments, each segment
    is run through the encoder, the per-segment first-token embeddings are
    mean-pooled into one document vector, and the linear head chooses the
    label. Returns "" for empty input and "—" when the predicted label has
    no entry in LABEL_INFO.
    """
    if not text or not text.strip():
        return ""
    cleaned = preprocess_text(str(text))
    if not cleaned.strip():
        return ""

    # Split into MAX_LENGTH-token segments (STRIDE=0 → no overlap).
    encoding = tokenizer(
        cleaned,
        truncation=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        return_overflowing_tokens=True,
        stride=STRIDE,
        return_tensors=None,
    )
    seg_ids = encoding["input_ids"]
    seg_mask = encoding["attention_mask"]

    # Cap very long documents so inference stays bounded.
    if MAX_SEGS_EVAL is not None and len(seg_ids) > MAX_SEGS_EVAL:
        seg_ids = seg_ids[:MAX_SEGS_EVAL]
        seg_mask = seg_mask[:MAX_SEGS_EVAL]

    ids_t = torch.tensor(seg_ids, dtype=torch.long, device=DEVICE)
    mask_t = torch.tensor(seg_mask, dtype=torch.long, device=DEVICE)
    n_segs = ids_t.shape[0]

    # Accumulate the first-token embedding of every segment, running the
    # encoder in mini-batches of SEG_FORWARD_BS to bound memory use.
    pooled = torch.zeros((hidden,), device=DEVICE, dtype=torch.float32)
    seen = 0
    for lo in range(0, n_segs, SEG_FORWARD_BS):
        hi = min(n_segs, lo + SEG_FORWARD_BS)
        states = encoder(
            input_ids=ids_t[lo:hi, :],
            attention_mask=mask_t[lo:hi, :],
        ).last_hidden_state
        cls_vecs = states[:, 0, :]
        pooled += cls_vecs.float().sum(dim=0)
        seen += hi - lo
    doc_vec = pooled / max(seen, 1)  # mean pool (guard is defensive; seen >= 1)

    logits = classifier(doc_vec.unsqueeze(0))[0]
    label = id2label[int(torch.argmax(logits).item())]
    meta = LABEL_INFO.get(label)
    if not meta:
        return "—"
    return f"Language: {meta['language']}\nCentury: {meta['century']}"
# --- Gradio UI: single text-in / text-out interface around predict() ---
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=12, placeholder="Paste a document here..."),
    outputs=gr.Textbox(label="Century & Language"),
    title="Document Classifier",
    description="Paste a document and get the predicted historical range.",
)
demo.queue()  # enable Gradio's request queue before launching
# Bind on all interfaces at port 7860 (the standard HF Spaces port).
# NOTE(review): ssr_mode=False presumably disables server-side rendering
# to avoid Spaces startup issues — confirm against the Gradio version used.
demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)