"""Gradio app: classify a document's historical Slavic language variety.

Loads a transformer encoder plus a separately stored linear classifier
head from the Hugging Face Hub, mean-pools segment [CLS] embeddings into
a document vector, and serves predictions through a Gradio interface.
"""

import json
import os
import re

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

# --- Configuration -----------------------------------------------------------
MODEL_ID = "usmannawaz/ocs"  # HF repo with encoder, label map, classifier head
MAX_LENGTH = 512             # tokens per segment fed to the encoder
STRIDE = 0                   # no token overlap between consecutive segments
MAX_SEGS_EVAL = 128          # cap on segments per document (memory bound)
SEG_FORWARD_BS = 8           # segments per encoder forward pass
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# --- Model loading -----------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
encoder = AutoModel.from_pretrained(MODEL_ID).to(DEVICE)
encoder.eval()
hidden = encoder.config.hidden_size

label_path = hf_hub_download(MODEL_ID, "label2id.json")
with open(label_path, "r", encoding="utf-8") as f:
    label2id = json.load(f)
id2label = {v: k for k, v in label2id.items()}

# The head is a bare Linear state dict. weights_only=True restricts
# unpickling to tensors/primitives, blocking arbitrary code execution
# from a downloaded artifact.
head_path = hf_hub_download(MODEL_ID, "doc_classifier_head.pt")
head_state = torch.load(head_path, map_location="cpu", weights_only=True)
classifier = torch.nn.Linear(hidden, len(label2id))
classifier.load_state_dict(head_state)
classifier = classifier.to(DEVICE)
classifier.eval()

# Human-readable metadata for each predicted label.
LABEL_INFO = {
    "Old Church Slavonic": {"century": "9th–11th", "language": "Old Church Slavonic"},
    "Church Slavonic": {"century": "12th–17th", "language": "Church Slavonic"},
    "New Church Slavonic": {"century": "18th", "language": "New Church Slavonic"},
    "Ruthenian": {"century": "15th–18th", "language": "Ruthenian"},
}

# Biblical verse references, e.g. "Mt. 5:3" / "Gen 1:1-3" (Latin or Cyrillic
# book abbreviations of up to two words, chapter/verse separated by :, . or ,).
BIBLE_REF_RE = re.compile(
    r"""\b(?:[^\W\d_]{1,20}\.?(?:\s+[^\W\d_]{1,20}\.?)?)\s*\d{1,3}\s*[:.,]\s*\d{1,3}(?:\s*[-–]\s*\d{1,3})?\b""",
    re.UNICODE,
)
# Editorial/typographic symbols that carry no linguistic signal.
NOISY_SYMBOLS_RE = re.compile(r"[†‡§¶¦|‖•·◦※◆◇■□▲△▼▽★☆✝︎✟✠]+", re.UNICODE)


def preprocess_text(t: str) -> str:
    """Strip verse references, numbering and editorial noise; lowercase.

    Returns "" for falsy input or input that reduces to whitespace.
    """
    if not t:
        return ""
    t = BIBLE_REF_RE.sub(" ", t)
    t = re.sub(r"\b\d+[A-Za-zА-Яа-я]{1,3}\b", " ", t)  # verse ids like "12a"
    t = re.sub(r"\b\d{1,3}\b", " ", t)                 # bare small numbers
    t = re.sub(r"\[\s*\.{1,}\s*\]", " ", t)            # lacuna markers "[...]"
    t = re.sub(r"\[([^\W\d_]{1,3})\]", r"\1", t, flags=re.UNICODE)  # unwrap "[xy]"
    t = NOISY_SYMBOLS_RE.sub(" ", t)
    t = re.sub(r"\s+", " ", t).strip()
    return t.lower()


@torch.no_grad()
def predict(text: str) -> str:
    """Classify a document and format the result for the UI.

    The cleaned text is tokenized into up to MAX_SEGS_EVAL segments of
    MAX_LENGTH tokens (via return_overflowing_tokens); each segment's
    first-token ([CLS]) embedding is mean-pooled into one document vector,
    which the linear head classifies.

    Returns a "Language: .../Century: ..." string, "" when the input is
    empty (before or after preprocessing), or "—" for a label missing
    from LABEL_INFO.
    """
    if not text or not text.strip():
        return ""
    text = preprocess_text(str(text))
    if not text.strip():
        return ""

    tok = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        return_overflowing_tokens=True,  # one row per MAX_LENGTH-token segment
        stride=STRIDE,
        return_tensors=None,
    )
    input_ids = tok["input_ids"]
    attention_mask = tok["attention_mask"]
    if MAX_SEGS_EVAL is not None and len(input_ids) > MAX_SEGS_EVAL:
        input_ids = input_ids[:MAX_SEGS_EVAL]
        attention_mask = attention_mask[:MAX_SEGS_EVAL]

    S = len(input_ids)
    input_ids = torch.tensor(input_ids, dtype=torch.long, device=DEVICE)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long, device=DEVICE)

    # Mean-pool [CLS] embeddings across all segments, forwarding in
    # mini-batches of SEG_FORWARD_BS to bound peak memory.
    sum_emb = torch.zeros((hidden,), device=DEVICE, dtype=torch.float32)
    count = 0
    for start in range(0, S, SEG_FORWARD_BS):
        end = min(S, start + SEG_FORWARD_BS)
        out = encoder(
            input_ids=input_ids[start:end, :],
            attention_mask=attention_mask[start:end, :],
        )
        seg_emb = out.last_hidden_state[:, 0, :]  # [CLS] vector per segment
        sum_emb += seg_emb.float().sum(dim=0)
        count += seg_emb.shape[0]

    doc_emb = sum_emb / max(count, 1)  # max() guards against a zero count
    logits = classifier(doc_emb.unsqueeze(0))[0]
    pred_id = int(torch.argmax(logits).item())
    label = id2label[pred_id]
    info = LABEL_INFO.get(label)
    if not info:
        return "—"
    return f"Language: {info['language']}\nCentury: {info['century']}"


# --- UI ----------------------------------------------------------------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=12, placeholder="Paste a document here..."),
    outputs=gr.Textbox(label="Century & Language"),
    title="Document Classifier",
    description="Paste a document and get the predicted historical range.",
)
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)