# Hugging Face Space: historical Slavic document classifier (Gradio app).
import json
import os
import re

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModel

# Hub repo hosting the fine-tuned encoder plus the auxiliary files
# (label2id.json, doc_classifier_head.pt) downloaded below.
MODEL_ID = "usmannawaz/ocs"
MAX_LENGTH = 512      # tokens per segment fed to the encoder
STRIDE = 0            # token overlap between consecutive segments (none)
MAX_SEGS_EVAL = 128   # cap on segments scored per document (None = unlimited)
SEG_FORWARD_BS = 8    # segments per encoder forward pass
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Silence the fast-tokenizer fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# --- Load model artifacts from the Hub (network I/O happens at import) ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
encoder = AutoModel.from_pretrained(MODEL_ID).to(DEVICE)
encoder.eval()
hidden = encoder.config.hidden_size

# Label mapping {"label name": id}; inverted for decoding predictions.
label_path = hf_hub_download(MODEL_ID, "label2id.json")
with open(label_path, "r", encoding="utf-8") as f:
    label2id = json.load(f)
id2label = {v: k for k, v in label2id.items()}

# Linear document-classification head stored separately from the encoder.
# weights_only=True: deserialize tensors only, never arbitrary pickled code
# (the checkpoint is downloaded data and should be treated as untrusted).
head_path = hf_hub_download(MODEL_ID, "doc_classifier_head.pt")
head_state = torch.load(head_path, map_location="cpu", weights_only=True)
classifier = torch.nn.Linear(hidden, len(label2id))
classifier.load_state_dict(head_state)
classifier = classifier.to(DEVICE)
classifier.eval()
# Human-readable metadata rendered for each predicted label.
LABEL_INFO = {
    "Old Church Slavonic": {"century": "9th–11th", "language": "Old Church Slavonic"},
    "Church Slavonic": {"century": "12th–17th", "language": "Church Slavonic"},
    "New Church Slavonic": {"century": "18th", "language": "New Church Slavonic"},
    "Ruthenian": {"century": "15th–18th", "language": "Ruthenian"},
}
# Scripture references such as "Ин. 3:16" or "Gen 1:1-3": one or two short
# alphabetic tokens (optionally dotted) followed by chapter/verse numbers.
BIBLE_REF_RE = re.compile(
    r"""\b(?:[^\W\d_]{1,20}\.?(?:\s+[^\W\d_]{1,20}\.?)?)\s*\d{1,3}\s*[:.,]\s*\d{1,3}(?:\s*[-–]\s*\d{1,3})?\b""",
    re.UNICODE,
)
# Typographic/editorial marks that carry no linguistic signal.
NOISY_SYMBOLS_RE = re.compile(r"[†‡§¶¦|‖•·◦※◆◇■□▲△▼▽★☆✝︎✟✠]+", re.UNICODE)

# Per-call substitution patterns, compiled once at import time.
_NUM_LETTER_RE = re.compile(r"\b\d+[A-Za-zА-Яа-я]{1,3}\b")           # e.g. "12a", "3об"
_BARE_NUM_RE = re.compile(r"\b\d{1,3}\b")                            # stray verse/page numbers
_LACUNA_RE = re.compile(r"\[\s*\.{1,}\s*\]")                         # editorial gaps like "[...]"
_BRACKET_LETTER_RE = re.compile(r"\[([^\W\d_]{1,3})\]", re.UNICODE)  # "[а]" -> "а"
_WS_RE = re.compile(r"\s+")


def preprocess_text(t: str) -> str:
    """Strip references, numbers, and editorial noise; return lowercased text.

    Empty or falsy input yields "".
    """
    if not t:
        return ""
    t = BIBLE_REF_RE.sub(" ", t)
    t = _NUM_LETTER_RE.sub(" ", t)
    t = _BARE_NUM_RE.sub(" ", t)
    t = _LACUNA_RE.sub(" ", t)
    t = _BRACKET_LETTER_RE.sub(r"\1", t)  # unwrap editorially restored letters
    t = NOISY_SYMBOLS_RE.sub(" ", t)
    t = _WS_RE.sub(" ", t).strip()
    return t.lower()
def predict(text: str) -> str:
    """Classify a document; return "Language: ...\\nCentury: ..." for display.

    The text is cleaned, tokenized into MAX_LENGTH-token segments, each
    segment's [CLS] embedding is extracted, and the mean embedding is fed
    to the linear head. Returns "" for empty input (or input that becomes
    empty after preprocessing) and "—" for a label missing from LABEL_INFO.
    """
    if not text or not text.strip():
        return ""
    text = preprocess_text(str(text))
    if not text.strip():
        return ""

    tok = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        return_overflowing_tokens=True,  # long docs become several segments
        stride=STRIDE,
        return_tensors=None,  # plain python lists; tensors built below
    )
    input_ids = tok["input_ids"]
    attention_mask = tok["attention_mask"]
    # Cap segment count so pathologically long documents stay cheap to score.
    if MAX_SEGS_EVAL is not None and len(input_ids) > MAX_SEGS_EVAL:
        input_ids = input_ids[:MAX_SEGS_EVAL]
        attention_mask = attention_mask[:MAX_SEGS_EVAL]

    S = len(input_ids)
    input_ids = torch.tensor(input_ids, dtype=torch.long, device=DEVICE)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long, device=DEVICE)

    # Mean-pool the per-segment [CLS] embeddings, forwarding in mini-batches.
    sum_emb = torch.zeros((hidden,), device=DEVICE, dtype=torch.float32)
    count = 0
    # no_grad: inference only — without it every forward pass retains an
    # autograd graph and memory grows with document length.
    with torch.no_grad():
        for start in range(0, S, SEG_FORWARD_BS):
            end = min(S, start + SEG_FORWARD_BS)
            out = encoder(
                input_ids=input_ids[start:end, :],
                attention_mask=attention_mask[start:end, :],
            )
            seg_emb = out.last_hidden_state[:, 0, :]  # [CLS] token per segment
            sum_emb += seg_emb.float().sum(dim=0)
            count += seg_emb.shape[0]
        doc_emb = sum_emb / max(count, 1)
        logits = classifier(doc_emb.unsqueeze(0))[0]
    pred_id = int(torch.argmax(logits).item())

    label = id2label[pred_id]
    info = LABEL_INFO.get(label)
    if not info:
        return "—"  # model knows a label that LABEL_INFO does not describe
    return f"Language: {info['language']}\nCentury: {info['century']}"
# --- Gradio UI ---
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=12, placeholder="Paste a document here..."),
    outputs=gr.Textbox(label="Century & Language"),
    title="Document Classifier",
    description="Paste a document and get the predicted historical range.",
)

if __name__ == "__main__":
    # Guarded so importing this module (e.g. for tests) does not bind a server;
    # running the file directly — as Spaces does — still launches the app.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)