Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| app.py — BeRestoral | |
| """ | |
| import html | |
| import json | |
| import math | |
| import re | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import uvicorn | |
| from fastapi import FastAPI, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from pydantic import BaseModel | |
| from transformers import AutoModelForMaskedLM, AutoTokenizer | |
# All inference runs on the CPU.
device = torch.device("cpu")

# Hugging Face Hub repo ids: the BPE (RoFormer) model and the char-level
# DualEmb model; the char repo also provides char_vocab.json / word_vocab.json.
MODEL_PATH_BPE = "MaximEremeev/RoFormer-slav"
MODEL_PATH_CHAR = "MaximEremeev/DualEmb-slav"

# Local directory holding the trained probe-head checkpoints (*.pth).
PROBE_DIR = Path("probes")

# Web app: static assets under /static, HTML pages rendered from ./templates.
app = FastAPI(title="BeRestoral")
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
BIN_START = 1050
BIN_SIZE = 50
N_BINS = 9

# Inclusive (first_year, last_year) ranges; one per output class of the date probe.
BINS = [
    (first, first + BIN_SIZE - 1)
    for first in range(BIN_START, BIN_START + N_BINS * BIN_SIZE, BIN_SIZE)
]
# Bin centres, used to turn the date-probe distribution into an expected year.
BIN_MIDPOINTS = np.array([(first + last) / 2 for first, last in BINS])
# Display labels such as "1050–1099" (en dash).
BIN_LABELS = [f"{first}–{last}" for first, last in BINS]

# Genre classes of the category probe, in output-index order.
CATEGORY_LABELS = ["letters", "records", "religious", "other"]
# Russian display names, index-aligned with CATEGORY_LABELS.
CATEGORY_LABELS_RU = ["письма", "деловые записи", "религиозные тексты", "другое"]
print("Loading BPE model (RoFormer)...")
tokenizer_bpe = AutoTokenizer.from_pretrained(MODEL_PATH_BPE, trust_remote_code=True)
# "[GAP]" marks a lacuna of unknown extent (user "#" is rewritten to it before
# tokenisation), so it must be a single, unsplittable token.
_n_new_bpe_tokens = tokenizer_bpe.add_special_tokens({"additional_special_tokens": ["[GAP]"]})
model_bpe = AutoModelForMaskedLM.from_pretrained(MODEL_PATH_BPE, trust_remote_code=True).to(device)
# add_special_tokens returns how many ids it created. A freshly created id lies
# past the checkpoint's vocabulary, and feeding it to the model would index
# outside the embedding matrix, so grow the matrix to cover it. Only ever grow:
# a checkpoint may pad its matrix beyond len(tokenizer), and shrinking would
# drop real rows.
# NOTE(review): a newly created row is randomly initialised, so [GAP] would carry
# no learned meaning -- confirm the hub tokenizer already ships [GAP] (then this
# branch is a no-op).
if _n_new_bpe_tokens and len(tokenizer_bpe) > model_bpe.get_input_embeddings().num_embeddings:
    model_bpe.resize_token_embeddings(len(tokenizer_bpe))
model_bpe.eval()

print("Loading char model (DualEmbLM)...")
from huggingface_hub import hf_hub_download
model_char = AutoModelForMaskedLM.from_pretrained(
    MODEL_PATH_CHAR, trust_remote_code=True).to(device)
model_char.eval()
# The char model's character and word vocabularies are JSON files in its hub repo.
_char_vocab_path = hf_hub_download(repo_id=MODEL_PATH_CHAR, filename="char_vocab.json")
_word_vocab_path = hf_hub_download(repo_id=MODEL_PATH_CHAR, filename="word_vocab.json")
char_vocab = json.loads(Path(_char_vocab_path).read_text(encoding="utf-8"))
word_vocab = json.loads(Path(_word_vocab_path).read_text(encoding="utf-8"))
# Reverse lookup used to turn predicted char ids back into text.
id_to_char = {v: k for k, v in char_vocab.items()}
# Input width of the probe heads; must equal the size of the pooled RoFormer
# embedding (presumably the model's hidden size -- confirm against its config).
EMBED_DIM = 512
class LinearProbe(nn.Module):
    """Linear classification head applied to a pooled encoder embedding.

    Layers, in order, inside ``self.net``: LayerNorm -> Dropout -> Linear.
    The resulting state-dict keys (``net.0.*`` and ``net.2.*``) are what the
    saved probe checkpoints are loaded into, so the layout must not change.
    """

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
        super().__init__()
        layers = [
            nn.LayerNorm(in_dim),
            nn.Dropout(dropout),
            nn.Linear(in_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_dim)`` to unnormalised logits ``(..., out_dim)``."""
        return self.net(x)
print("Loading probe classifiers...")
probe_category = LinearProbe(EMBED_DIM, len(CATEGORY_LABELS))
probe_date = LinearProbe(EMBED_DIM, N_BINS)
# Restore each head from its checkpoint (tensors only, via weights_only=True)
# and switch it to inference mode.
for _probe, _ckpt_name in (
    (probe_category, "RoFormer_category_masked_probe.pth"),
    (probe_date, "RoFormer_date_masked_probe.pth"),
):
    _probe.load_state_dict(
        torch.load(PROBE_DIR / _ckpt_name, map_location=device, weights_only=True)
    )
    _probe.eval()
del _probe, _ckpt_name
print("All models loaded.")
# Special tokens and separator characters treated as atomic units: bracketed
# model tokens plus the single characters "+", ":" and "·".
SPECIAL_RE = re.compile(r"(\[GAP\]|\[MASK\]|\[PAD\]|\[UNK\]|\[CLS\]|\[SEP\]|[+:·])")


def split_special(text: str) -> list[str]:
    """Split *text* around SPECIAL_RE matches, keeping the matches, dropping empty pieces.

    >>> split_special("ab[MASK]c")
    ['ab', '[MASK]', 'c']
    """
    return [p for p in SPECIAL_RE.split(text) if p]


def align_char_to_word(text: str, char_v: dict, word_v: dict, max_len: int = 256,
                       max_char_id: Optional[int] = None,
                       max_word_id: Optional[int] = None):
    """Encode *text* for the char model as parallel char-id / word-id sequences.

    Each character gets its char id plus the id of the whitespace-delimited word
    it belongs to. Whitespace characters get the unknown word id. A SPECIAL_RE
    match occupies a single position, looked up as a whole in both vocabularies.
    The sequence is wrapped in [CLS] ... [SEP]. If it exceeds *max_len*, it is
    cut to *max_len* and the last kept position is forced back to [SEP].

    Args:
        text: input text; leading/trailing whitespace is stripped.
        char_v: char -> id mapping; must contain "[UNK]", "[SEP]", "[CLS]".
        word_v: word -> id mapping; "[UNK_WORD]" (default 0) is the unknown id.
        max_len: maximum sequence length including [CLS]/[SEP].
        max_char_id / max_word_id: largest ids the model's embedding tables
            accept; larger ids are replaced by the matching unknown id. When
            omitted they are read from ``model_char.config``
            (``vocab_char_size - 1`` / ``vocab_word_size - 1``).

    Returns:
        ``{"input_ids": [...], "word_ids": [...]}``, two int lists of equal length.

    Raises:
        KeyError: if *char_v* lacks "[UNK]", "[SEP]" or "[CLS]".
    """
    c_unk, c_sep, c_cls = char_v["[UNK]"], char_v["[SEP]"], char_v["[CLS]"]
    w_unk = word_v.get("[UNK_WORD]", 0)
    w_sep = word_v.get("[SEP]", w_unk)

    input_ids, word_ids = [c_cls], [word_v.get("[CLS]", w_unk)]
    for part in split_special(text.strip()):
        if SPECIAL_RE.fullmatch(part):
            input_ids.append(char_v.get(part, c_unk))
            word_ids.append(word_v.get(part, w_unk))
            continue
        # Capturing split keeps whitespace runs as their own chunks.
        for chunk in re.split(r"(\s+)", part):
            if not chunk:
                continue
            wid = w_unk if chunk.isspace() else word_v.get(chunk, w_unk)
            input_ids.extend(char_v.get(ch, c_unk) for ch in chunk)
            word_ids.extend([wid] * len(chunk))
    input_ids.append(c_sep)
    word_ids.append(w_sep)

    if len(input_ids) > max_len:
        input_ids, word_ids = input_ids[:max_len], word_ids[:max_len]
        input_ids[-1], word_ids[-1] = c_sep, w_sep

    # Fall back to the loaded char model's vocabulary sizes only when the caller
    # did not supply explicit limits.
    if max_char_id is None:
        max_char_id = model_char.config.vocab_char_size - 1
    if max_word_id is None:
        max_word_id = model_char.config.vocab_word_size - 1
    return {
        "input_ids": [x if x <= max_char_id else c_unk for x in input_ids],
        "word_ids": [x if x <= max_word_id else w_unk for x in word_ids],
    }
def get_roformer_embedding(text: str) -> torch.Tensor:
    """Return the mean of RoFormer's last-layer hidden states over attended positions.

    Whitespace runs in *text* are collapsed to single spaces and the input is
    truncated to 512 tokens. Lacunae are expected to already be written as the
    BPE mask token. The average covers every position the attention mask marks
    as 1, including any special tokens the tokenizer adds. The batch dimension
    is squeezed away before returning.
    """
    normalized = re.sub(r"\s+", " ", text).strip()
    batch = tokenizer_bpe(normalized, return_tensors="pt", truncation=True,
                          max_length=512, return_attention_mask=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        last_hidden = model_bpe(**batch, output_hidden_states=True).hidden_states[-1]

    weights = batch["attention_mask"].unsqueeze(-1).float()
    pooled = (last_hidden * weights).sum(dim=1) / weights.sum(dim=1)
    return pooled.squeeze(0)
def classify(text: str) -> dict:
    """Predict genre and date for *text* with the two probe heads on the pooled embedding.

    Returns a dict with:
      category / category_ru -- arg-max genre label (English / Russian);
      category_probs         -- label -> probability, rounded to 4 places;
      pred_year              -- probability-weighted mean of BIN_MIDPOINTS, rounded;
      date_probs             -- per-bin probabilities, rounded to 4 places;
      bin_labels             -- BIN_LABELS, aligned with date_probs.
    """
    embedding = get_roformer_embedding(text).unsqueeze(0)
    with torch.no_grad():
        cat_probs = torch.softmax(probe_category(embedding)[0], dim=-1).cpu().numpy().tolist()
        date_probs = torch.softmax(probe_date(embedding)[0], dim=-1).cpu().numpy().tolist()

    top = int(np.argmax(cat_probs))
    expected_year = float(np.dot(date_probs, BIN_MIDPOINTS))

    result = {}
    result["category"] = CATEGORY_LABELS[top]
    result["category_ru"] = CATEGORY_LABELS_RU[top]
    result["category_probs"] = {
        label: round(prob, 4) for label, prob in zip(CATEGORY_LABELS, cat_probs)
    }
    result["pred_year"] = round(expected_year)
    result["date_probs"] = [round(prob, 4) for prob in date_probs]
    result["bin_labels"] = BIN_LABELS
    return result
def _mlm_logits(model, batch_ids: torch.Tensor,
                word_ids: Optional[torch.Tensor]) -> torch.Tensor:
    """Run *model* on a batch of equal-length id rows and return its ``.logits``.

    *word_ids* is the single (unbatched) word-id row used by the char model; it
    is broadcast to every row of *batch_ids*. Pass ``None`` for the BPE model.
    """
    if word_ids is None:
        return model(input_ids=batch_ids).logits
    batch_wids = word_ids.unsqueeze(0).expand(batch_ids.size(0), -1).to(device)
    return model(input_ids=batch_ids, word_ids=batch_wids).logits


def _decode_step_token(token_id: Optional[int], is_char: bool) -> str:
    """Human-readable form of one predicted id, for the step log."""
    if token_id is None:
        return ""
    if is_char:
        return id_to_char.get(token_id, "")
    return (tokenizer_bpe.decode([token_id], clean_up_tokenization_spaces=False)
            .replace("Ġ", "").replace("##", "").strip())


def _partial_sentence_char(input_ids, inserted, original_mask_indices, target_pos) -> str:
    """Render the char-model sequence with the fills made so far.

    The first ([CLS]) and last ([SEP]) positions of *input_ids* are dropped.
    Unfilled masks stay "[MASK]". The character at *target_pos* is wrapped in
    ``[[R]]...[[/R]]`` unless it is a special-token string.
    """
    tokens = [id_to_char.get(tid.item(), "") for tid in input_ids[1:-1]]
    for pos in original_mask_indices:
        filled = inserted.get(pos)
        slot = pos - 1  # shift for the dropped [CLS]
        if filled is not None and 0 <= slot < len(tokens):
            tokens[slot] = id_to_char.get(filled, "")
    specials = ("[MASK]", "[GAP]", "[PAD]", "[UNK]", "[CLS]", "[SEP]")
    target = target_pos - 1
    return "".join(
        tok if tok in specials or i != target else f"[[R]]{tok}[[/R]]"
        for i, tok in enumerate(tokens)
    )


def _partial_sentence_bpe(input_ids, inserted, original_mask_indices, target_pos) -> str:
    """Render the BPE sequence with the fills made so far.

    Unfilled masks are shown as "[MASK]". The token at *target_pos* is wrapped
    in ``[[R]]...[[/R]]``. The tokenizer's CLS/SEP/PAD tokens are omitted, and
    runs of spaces are collapsed.
    """
    def _dec(tid):
        return tokenizer_bpe.decode([tid], clean_up_tokenization_spaces=False)

    tokens = [_dec(tid.item()) for tid in input_ids]
    for pos in original_mask_indices:
        filled = inserted.get(pos)
        if filled is not None and 0 <= pos < len(tokens):
            tokens[pos] = _dec(filled)

    skipped = (tokenizer_bpe.cls_token, tokenizer_bpe.sep_token, tokenizer_bpe.pad_token)
    parts = []
    for i, tok in enumerate(tokens):
        clean = tok.replace("Ġ", " ").replace("##", "")
        if tok == tokenizer_bpe.mask_token:
            parts.append("[MASK]")
        elif i == target_pos:
            parts.append(f"[[R]]{clean}[[/R]]")
        elif tok not in skipped:
            parts.append(clean)
    return re.sub(r" +", " ", "".join(parts)).strip()


def _build_variants(states, text, original_mask_indices, is_char, mask_str) -> list:
    """Turn the final beams into the response's ``variants`` list."""
    escaped_mask = html.escape(mask_str)
    # Splitting once (instead of repeated str.replace(..., 1)) ensures each fill
    # lands on its own mask even when a predicted token is itself the mask string.
    segments = html.escape(text).split(escaped_mask)
    variants = []
    for state in states:
        ordered_ids = [state["inserted_tokens"][i] for i in original_mask_indices]
        if is_char:
            inserted = "".join(id_to_char.get(t, "") for t in ordered_ids).strip()
            pieces = [html.escape(id_to_char.get(t, "")) for t in ordered_ids]
        else:
            inserted = tokenizer_bpe.decode(ordered_ids,
                                            clean_up_tokenization_spaces=True).strip()
            pieces = [
                html.escape(tokenizer_bpe.decode([t])
                            .replace("Ġ", "").replace("##", "").replace(" ", ""))
                for t in ordered_ids
            ]
        out = [segments[0]]
        for k, seg in enumerate(segments[1:]):
            # Masks with no prediction (e.g. cut off by truncation) stay as-is.
            out.append(f'<span class="highlight-restored">{pieces[k]}</span>'
                       if k < len(pieces) else escaped_mask)
            out.append(seg)
        full_sentence = re.sub(r"\s+", " ", "".join(out).strip())
        variants.append({
            "word": inserted or "...",
            "score": round(math.exp(state["log_prob"]) * 100, 2),
            "full_sentence": full_sentence,
            "raw_log_prob": state["log_prob"],
        })
    return variants


def generate_sequential(text: str, is_char: bool,
                        top_k: int = 5, temperature: float = 1.0) -> dict:
    """Fill every mask token in *text* by beam search, one mask position per step.

    At each step, pick the still-unfilled mask where the current best beam is
    most confident (highest max softmax probability). Extend every beam with
    its *top_k* tokens for that position, then keep the *top_k* sequences with
    the highest cumulative log-probability.

    Args:
        text: input containing mask tokens ("[MASK]" for the char model, the
            tokenizer's mask token for the BPE model).
        is_char: use the char model (True) or the BPE model (False).
        top_k: beam width and number of variants returned; clamped to >= 1.
        temperature: softmax temperature, floored at 0.01.

    Returns:
        ``{"variants": [...], "steps": [...]}``. Both lists are empty if no mask
        is found. Each step has "pos", "confidence" (percent), "token" and
        "partial_sentence" (filled text with that step's token wrapped in
        ``[[R]]...[[/R]]``). Each variant has "word", "score" (percent),
        "full_sentence" (HTML-escaped *text* with the fills wrapped in spans)
        and "raw_log_prob".
    """
    top_k = max(1, int(top_k))
    temp = max(0.01, temperature)

    if is_char:
        encoded = align_char_to_word(text, char_vocab, word_vocab)
        input_ids = torch.tensor(encoded["input_ids"]).to(device)
        word_ids = torch.tensor(encoded["word_ids"]).to(device)
        mask_token_id = char_vocab["[MASK]"]
        mask_str = "[MASK]"
        model = model_char
    else:
        # Same 512-token cap as get_roformer_embedding uses for this model.
        inputs = tokenizer_bpe(text, return_tensors="pt", truncation=True,
                               max_length=512).to(device)
        input_ids = inputs["input_ids"][0]
        word_ids = None
        mask_token_id = tokenizer_bpe.mask_token_id
        mask_str = tokenizer_bpe.mask_token
        model = model_bpe

    original_mask_indices = torch.where(input_ids == mask_token_id)[0].tolist()
    if not original_mask_indices:
        return {"variants": [], "steps": []}

    render_partial = _partial_sentence_char if is_char else _partial_sentence_bpe
    states = [{"input_ids": input_ids.clone(), "log_prob": 0.0, "inserted_tokens": {}}]
    unfilled = list(original_mask_indices)
    steps = []

    with torch.no_grad():
        while unfilled:
            # One forward pass per step: row 0 is the current best beam, so its
            # logits drive the choice of which mask to fill next.
            batch_ids = torch.stack([s["input_ids"] for s in states]).to(device)
            logits = _mlm_logits(model, batch_ids, word_ids)

            lead = logits[0]
            best_pos, best_conf = None, -1.0
            for m_idx in unfilled:
                p = torch.softmax(lead[m_idx] / temp, dim=-1).max().item()
                if p > best_conf:
                    best_conf, best_pos = p, m_idx
            unfilled.remove(best_pos)
            step = {"pos": best_pos, "confidence": round(best_conf * 100, 1)}
            steps.append(step)

            probs = torch.softmax(logits[:, best_pos, :] / temp, dim=-1)
            top_probs, top_ids = torch.topk(probs, min(top_k, probs.size(-1)), dim=-1)

            candidates = []
            for si, state in enumerate(states):
                for prob, tid in zip(top_probs[si].tolist(), top_ids[si].tolist()):
                    new_ids = state["input_ids"].clone()
                    new_ids[best_pos] = tid
                    candidates.append({
                        "input_ids": new_ids,
                        "log_prob": state["log_prob"] + math.log(max(prob, 1e-9)),
                        "inserted_tokens": {**state["inserted_tokens"], best_pos: tid},
                    })
            states = sorted(candidates, key=lambda c: c["log_prob"], reverse=True)[:top_k]

            best = states[0]
            step["token"] = _decode_step_token(best["inserted_tokens"].get(best_pos), is_char)
            step["partial_sentence"] = render_partial(
                input_ids, best["inserted_tokens"], original_mask_indices, best_pos)

    return {
        "variants": _build_variants(states, text, original_mask_indices, is_char, mask_str),
        "steps": steps,
    }
class RestoreRequest(BaseModel):
    """JSON request body for the text-restoration handler (restore_text)."""

    text: str  # input text; each "-" becomes one mask token, each "#" becomes "[GAP]"
    mode: str = "char"  # "char" selects the char model; any other value uses the BPE model
    top_k: int = 5  # beam width, which is also the number of variants returned
    temperature: float = 1.0  # softmax temperature (floored at 0.01 during generation)
@app.get("/")
async def read_root(request: Request):
    """Serve the UI page rendered from the "index.html" template.

    The route decorator was missing, so FastAPI never registered this handler.
    """
    return templates.TemplateResponse(request=request, name="index.html")
async def restore_text(req: RestoreRequest) -> Dict[str, Any]:
    """Classify the request text and restore its lacunae.

    Input conventions: "-" is one missing unit (replaced by a mask token) and
    "#" is rewritten to "[GAP]". Everything except SPECIAL_RE tokens is
    lower-cased. Classification always uses the BPE mask token. Restoration
    uses "[MASK]" in char mode and the BPE mask token otherwise.

    Returns ``{"status": "success", "results", "steps", "n_gaps",
    "classification"}``, or on any exception ``{"status": "error", "message",
    "traceback"}``.
    """
    # NOTE(review): no route decorator is visible on this handler, so FastAPI
    # does not register it; add the path the front-end calls. The error payload
    # also sends the server traceback to the client -- consider logging it instead.
    try:
        is_char = req.mode == "char"
        bpe_mask = tokenizer_bpe.mask_token
        restore_mask = "[MASK]" if is_char else bpe_mask

        gapped = req.text.replace("#", "[GAP]")
        text = "".join(
            piece if SPECIAL_RE.fullmatch(piece) else piece.lower()
            for piece in SPECIAL_RE.split(gapped)
            if piece
        )
        n_gaps = text.count("-") + text.count("[GAP]")

        classify_input = re.sub(r" +", " ", re.sub(r"-", bpe_mask, text)).strip()
        classification = classify(classify_input)

        query = re.sub(r" +", " ", text.replace("-", restore_mask)).strip()
        restored = generate_sequential(query, is_char, req.top_k, req.temperature)

        return {
            "status": "success",
            "results": [restored["variants"]],
            "steps": [restored["steps"]],
            "n_gaps": n_gaps,
            "classification": classification,
        }
    except Exception as e:
        import traceback
        return {"status": "error", "message": str(e),
                "traceback": traceback.format_exc()}
if __name__ == "__main__":
    import os

    # Listen on all interfaces; the port comes from $PORT and defaults to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)