#!/usr/bin/env python3
"""
app.py — BeRestoral
"""
import html
import json
import math
import re
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import uvicorn
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from transformers import AutoModelForMaskedLM, AutoTokenizer
# FastAPI application object; the route handlers below are registered on it.
app = FastAPI(title="BeRestoral")
# Serve files from the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Jinja2 template loader rooted at the local "templates" directory.
templates = Jinja2Templates(directory="templates")
# Every model and tensor in this module is placed on the CPU.
device = torch.device("cpu")
# Model identifiers handed to `from_pretrained` / `hf_hub_download`.
MODEL_PATH_BPE = "MaximEremeev/RoFormer-slav"
MODEL_PATH_CHAR = "MaximEremeev/DualEmb-slav"
# Directory that holds the probe checkpoints (*.pth).
PROBE_DIR = Path("probes")

# Dating bins: N_BINS consecutive, inclusive BIN_SIZE-year ranges that start
# at BIN_START, i.e. (1050, 1099), (1100, 1149), ..., (1450, 1499).
BIN_START = 1050
BIN_SIZE = 50
N_BINS = 9
BINS = [
    (first_year, first_year + BIN_SIZE - 1)
    for first_year in range(BIN_START, BIN_START + N_BINS * BIN_SIZE, BIN_SIZE)
]
# Centre of each bin; used to turn a bin distribution into an expected year.
BIN_MIDPOINTS = np.array([sum(bounds) / 2 for bounds in BINS])
# Human-readable "lo–hi" label for each bin.
BIN_LABELS = [f"{lo}–{hi}" for (lo, hi) in BINS]

# Text categories, index-aligned with each other (English / Russian).
CATEGORY_LABELS = ["letters", "records", "religious", "other"]
CATEGORY_LABELS_RU = ["письма", "деловые записи", "религиозные тексты", "другое"]
# --- BPE masked language model (RoFormer) ---------------------------------
print("Loading BPE model (RoFormer)...")
# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repository to run locally — confirm the repositories are trusted.
tokenizer_bpe = AutoTokenizer.from_pretrained(MODEL_PATH_BPE, trust_remote_code=True)
# Register "[GAP]" as an additional special token on the BPE tokenizer.
# NOTE(review): the model's embeddings are not resized anywhere in this file;
# if "[GAP]" is not already in the vocabulary its id may lie outside the
# embedding table — confirm against the published tokenizer.
tokenizer_bpe.add_special_tokens({"additional_special_tokens": ["[GAP]"]})
model_bpe = AutoModelForMaskedLM.from_pretrained(MODEL_PATH_BPE, trust_remote_code=True).to(device)
# Inference only: switch the model to evaluation mode.
model_bpe.eval()
# --- Character-level masked language model (DualEmbLM) --------------------
print("Loading char model (DualEmbLM)...")
from huggingface_hub import hf_hub_download
model_char = AutoModelForMaskedLM.from_pretrained(
MODEL_PATH_CHAR, trust_remote_code=True).to(device)
model_char.eval()
# The char model has no tokenizer object here; its two vocabularies are
# downloaded from the same repository as JSON files and parsed into dicts.
_char_vocab_path = hf_hub_download(repo_id=MODEL_PATH_CHAR, filename="char_vocab.json")
_word_vocab_path = hf_hub_download(repo_id=MODEL_PATH_CHAR, filename="word_vocab.json")
char_vocab = json.loads(Path(_char_vocab_path).read_text(encoding="utf-8"))
word_vocab = json.loads(Path(_word_vocab_path).read_text(encoding="utf-8"))
# Inverse of char_vocab (id -> token string), used to decode predictions.
id_to_char = {v: k for k, v in char_vocab.items()}
# Input width of the probe classifiers; presumably equal to the RoFormer
# hidden size, since the probes consume its pooled hidden states — confirm.
EMBED_DIM = 512
class LinearProbe(nn.Module):
    """Linear classification head: LayerNorm -> Dropout -> Linear.

    The layers are stored in ``self.net`` (an ``nn.Sequential``), so the
    state-dict keys are ``net.0.*`` for the LayerNorm and ``net.2.*`` for the
    Linear layer.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Number of output logits.
        dropout: Dropout probability applied between the norm and the
            linear layer.
    """

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
        super().__init__()
        layers = [
            nn.LayerNorm(in_dim),
            nn.Dropout(dropout),
            nn.Linear(in_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits produced by the layer stack for ``x``."""
        logits = self.net(x)
        return logits
# --- Probe classifiers ------------------------------------------------------
print("Loading probe classifiers...")
# One logit per text category / per dating bin.
probe_category = LinearProbe(EMBED_DIM, len(CATEGORY_LABELS))
probe_date = LinearProbe(EMBED_DIM, N_BINS)
# Checkpoints are read from PROBE_DIR; weights_only=True restricts torch.load
# to tensor data instead of arbitrary pickled objects.
probe_category.load_state_dict(torch.load(
PROBE_DIR / "RoFormer_category_masked_probe.pth", map_location=device, weights_only=True))
probe_date.load_state_dict(torch.load(
PROBE_DIR / "RoFormer_date_masked_probe.pth", map_location=device, weights_only=True))
# Evaluation mode, so the probes' Dropout layer is inactive at inference.
probe_category.eval()
probe_date.eval()
print("All models loaded.")
# Matches one bracketed special token ([GAP], [MASK], [PAD], [UNK], [CLS],
# [SEP]) or one of the single characters "+", ":" and "·". The capturing group
# makes `split` keep the matched delimiters in its output.
SPECIAL_RE = re.compile(r"(\[GAP\]|\[MASK\]|\[PAD\]|\[UNK\]|\[CLS\]|\[SEP\]|[+:·])")


def split_special(text: str) -> list[str]:
    """Split *text* into special tokens and the plain runs between them.

    Empty pieces (produced when special tokens are adjacent or sit at either
    end of the string) are dropped; the order of the pieces is preserved.
    """
    pieces = SPECIAL_RE.split(text)
    return list(filter(None, pieces))
def align_char_to_word(text: str, char_v: dict, word_v: dict, max_len: int = 256):
    """Encode *text* as parallel character-id and word-id sequences.

    Every position carries the id of one character (or one special token)
    together with the id of the word that character belongs to. The sequence
    is wrapped in ``[CLS]`` ... ``[SEP]``.

    Args:
        text: Input string; surrounding whitespace is stripped first.
        char_v: Token -> id mapping for characters; must contain ``[UNK]``,
            ``[SEP]`` and ``[CLS]``.
        word_v: Word -> id mapping; ``[UNK_WORD]`` (default 0) is the fallback.
        max_len: Maximum sequence length; longer sequences are cut and their
            last position is overwritten with ``[SEP]``.

    Returns:
        ``{"input_ids": [...], "word_ids": [...]}`` — two equally long lists.
        Ids larger than the char model's configured ``vocab_char_size`` /
        ``vocab_word_size`` allow are replaced with the respective unknown id.
    """
    unk_char = char_v["[UNK]"]
    sep_char = char_v["[SEP]"]
    cls_char = char_v["[CLS]"]
    unk_word = word_v.get("[UNK_WORD]", 0)
    sep_word = word_v.get("[SEP]", unk_word)

    char_ids = [cls_char]
    aligned_word_ids = [word_v.get("[CLS]", unk_word)]

    for segment in split_special(text.strip()):
        if SPECIAL_RE.fullmatch(segment):
            # A special token occupies exactly one position.
            char_ids.append(char_v.get(segment, unk_char))
            aligned_word_ids.append(word_v.get(segment, unk_word))
            continue
        # The capturing split keeps the whitespace runs as separate pieces.
        for piece in re.split(r"(\s+)", segment):
            if not piece:
                continue
            # Whitespace belongs to no word; every other piece is looked up whole.
            if piece.isspace():
                piece_word = unk_word
            else:
                piece_word = word_v.get(piece, unk_word)
            char_ids.extend(char_v.get(ch, unk_char) for ch in piece)
            aligned_word_ids.extend([piece_word] * len(piece))

    char_ids.append(sep_char)
    aligned_word_ids.append(sep_word)

    if len(char_ids) > max_len:
        char_ids = char_ids[:max_len]
        aligned_word_ids = aligned_word_ids[:max_len]
        # Keep the sequence terminated after truncation.
        char_ids[-1] = sep_char
        aligned_word_ids[-1] = sep_word

    # Guard against vocabulary ids the loaded model has no embedding row for.
    char_limit = model_char.config.vocab_char_size - 1
    word_limit = model_char.config.vocab_word_size - 1
    return {
        "input_ids": [cid if cid <= char_limit else unk_char for cid in char_ids],
        "word_ids": [wid if wid <= word_limit else unk_word for wid in aligned_word_ids],
    }
def get_roformer_embedding(text: str) -> torch.Tensor:
    """Return the attention-mask-weighted mean of RoFormer's last hidden layer.

    Whitespace runs in *text* are collapsed to single spaces before
    tokenization, and the input is truncated to 512 tokens. *text* is expected
    to already contain the BPE mask token wherever there is a lacuna.

    Returns:
        The pooled embedding with the batch dimension squeezed away.
    """
    normalized = re.sub(r"\s+", " ", text).strip()
    encoded = tokenizer_bpe(
        normalized,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_attention_mask=True,
    )
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model_bpe(**model_inputs, output_hidden_states=True)

    last_hidden = outputs.hidden_states[-1]
    # Broadcast the attention mask over the hidden dimension for pooling.
    weights = model_inputs["attention_mask"].unsqueeze(-1).float()
    summed = (last_hidden * weights).sum(dim=1)
    token_count = weights.sum(dim=1)
    return (summed / token_count).squeeze(0)
def classify(text: str) -> dict:
    """Predict the category and the date of *text* with the two probes.

    The text is embedded with ``get_roformer_embedding`` and passed through
    the category probe and the date probe; both outputs are soft-maxed.

    Returns:
        A dict with:
        ``category`` / ``category_ru`` — label of the most probable category;
        ``category_probs`` — English label -> probability (4 decimals);
        ``pred_year`` — probability-weighted mean of the bin midpoints,
        rounded to an integer;
        ``date_probs`` — per-bin probabilities (4 decimals);
        ``bin_labels`` — the labels matching ``date_probs``.
    """
    features = get_roformer_embedding(text).unsqueeze(0)
    with torch.no_grad():
        category_logits = probe_category(features)[0]
        period_logits = probe_date(features)[0]
        category_probs = torch.softmax(category_logits, dim=-1).cpu().numpy().tolist()
        period_probs = torch.softmax(period_logits, dim=-1).cpu().numpy().tolist()

    winner = int(np.argmax(category_probs))
    # Expected year under the predicted distribution over dating bins.
    expected_year = float(np.dot(period_probs, BIN_MIDPOINTS))

    category_scores = {}
    for index, probability in enumerate(category_probs):
        category_scores[CATEGORY_LABELS[index]] = round(probability, 4)

    return {
        "category": CATEGORY_LABELS[winner],
        "category_ru": CATEGORY_LABELS_RU[winner],
        "category_probs": category_scores,
        "pred_year": round(expected_year),
        "date_probs": [round(probability, 4) for probability in period_probs],
        "bin_labels": BIN_LABELS,
    }
def generate_sequential(text: str, is_char: bool,
                        top_k: int = 5, temperature: float = 1.0) -> dict:
    """Fill every mask in *text* with a confidence-ordered beam search.

    At each step the best beam decides which remaining mask the model is most
    confident about. That position is then expanded with its most probable
    tokens in every beam, and the best beams (by summed log-probability) are
    kept for the next step.

    Args:
        text: Input containing mask tokens ("[MASK]" in char mode, the BPE
            tokenizer's mask token otherwise).
        is_char: Use the character-level model when true, the BPE model
            otherwise.
        top_k: Beam width and number of expansions per beam. Values below 1
            are treated as 1 and values above the vocabulary size are capped,
            so the search never ends up with an empty beam.
        temperature: Softmax temperature; values below 0.01 are raised to 0.01.

    Returns:
        ``{"variants": [...], "steps": [...]}``. Both lists are empty when the
        encoded input contains no mask.

        Each variant has ``word`` (the inserted tokens joined together, or
        "..." when empty), ``score`` (beam probability in percent),
        ``full_sentence`` (HTML-escaped text with every restored token wrapped
        in ``<span class="highlight-restored">``) and ``raw_log_prob``.

        Each step has ``pos`` (index into the encoded sequence),
        ``confidence`` (percent), ``token`` (best beam's prediction) and
        ``partial_sentence`` (best beam so far, with the token filled in this
        step wrapped in ``[[R]]...[[/R]]``).
    """
    if is_char:
        encoded = align_char_to_word(text, char_vocab, word_vocab)
        input_ids = torch.tensor(encoded["input_ids"]).to(device)
        word_ids = torch.tensor(encoded["word_ids"]).to(device)
        mask_token_id = char_vocab["[MASK]"]
        mask_str = "[MASK]"
        model = model_char
    else:
        inputs = tokenizer_bpe(text, return_tensors="pt").to(device)
        input_ids = inputs["input_ids"][0]
        word_ids = None
        mask_token_id = tokenizer_bpe.mask_token_id
        mask_str = tokenizer_bpe.mask_token
        model = model_bpe

    original_mask_indices = torch.where(input_ids == mask_token_id)[0].tolist()
    if not original_mask_indices:
        return {"variants": [], "steps": []}

    temp = max(0.01, temperature)
    beam_width = max(1, int(top_k))

    # Decoded form of the untouched input; it never changes between steps, so
    # it is computed once and only patched with predictions when rendering.
    if is_char:
        # Drop the leading [CLS] and the trailing [SEP].
        base_tokens = [id_to_char.get(tid, "") for tid in input_ids[1:-1].tolist()]
    else:
        base_tokens = [
            tokenizer_bpe.decode([tid], clean_up_tokenization_spaces=False)
            for tid in input_ids.tolist()
        ]

    beams = [{"input_ids": input_ids.clone(), "log_prob": 0.0,
              "inserted_tokens": {}}]
    unfilled_masks = list(original_mask_indices)
    steps = []

    with torch.no_grad():
        while unfilled_masks:
            # One forward pass per step: row 0 (the best beam) selects the
            # next mask, and all rows provide the expansion candidates.
            batch_ids = torch.stack([beam["input_ids"] for beam in beams]).to(device)
            if is_char:
                batch_wids = word_ids.unsqueeze(0).expand(len(beams), -1).to(device)
                logits = model(input_ids=batch_ids, word_ids=batch_wids).logits
            else:
                logits = model(input_ids=batch_ids).logits

            # Pick the unfilled mask whose top prediction is most confident.
            # torch.argmax returns the first maximum, so ties go to the
            # leftmost remaining mask.
            lead_probs = torch.softmax(logits[0, unfilled_masks] / temp, dim=-1)
            peak_probs = lead_probs.max(dim=-1).values
            pick = int(torch.argmax(peak_probs))
            highest_prob = peak_probs[pick].item()
            best_mask_idx = unfilled_masks.pop(pick)

            mask_probs = torch.softmax(logits[:, best_mask_idx, :] / temp, dim=-1)
            # torch.topk raises when k exceeds the size of the dimension.
            k = min(beam_width, mask_probs.shape[-1])
            top_probs, top_ids = torch.topk(mask_probs, k, dim=-1)

            candidates = []
            for beam, row_probs, row_ids in zip(beams, top_probs.tolist(), top_ids.tolist()):
                for prob, tid in zip(row_probs, row_ids):
                    new_ids = beam["input_ids"].clone()
                    new_ids[best_mask_idx] = tid
                    candidates.append({
                        "input_ids": new_ids,
                        # Floor the probability so log() never sees zero.
                        "log_prob": beam["log_prob"] + math.log(max(prob, 1e-9)),
                        "inserted_tokens": {**beam["inserted_tokens"], best_mask_idx: tid},
                    })
            beams = sorted(candidates, key=lambda c: c["log_prob"], reverse=True)[:k]

            best_beam = beams[0]
            best_id = best_beam["inserted_tokens"][best_mask_idx]
            if is_char:
                predicted = id_to_char.get(best_id, "")
                partial = _partial_sentence_char(
                    base_tokens, best_beam["inserted_tokens"], best_mask_idx)
            else:
                # Token id 0 is a valid id, so it is decoded like any other.
                predicted = tokenizer_bpe.decode(
                    [best_id], clean_up_tokenization_spaces=False
                ).replace("Ġ", "").replace("##", "").strip()
                partial = _partial_sentence_bpe(
                    base_tokens, best_beam["inserted_tokens"], best_mask_idx)

            steps.append({
                "pos": best_mask_idx,
                "confidence": round(highest_prob * 100, 1),
                "token": predicted,
                "partial_sentence": partial,
            })

    escaped_text = html.escape(text)
    escaped_mask = html.escape(mask_str)
    variants = [
        _build_variant(escaped_text, escaped_mask, beam, original_mask_indices, is_char)
        for beam in beams
    ]
    return {"variants": variants, "steps": steps}


# Char-vocabulary tokens that are emitted verbatim, never wrapped in [[R]].
_CHAR_SPECIAL_TOKENS = frozenset(
    ("[MASK]", "[GAP]", "[PAD]", "[UNK]", "[CLS]", "[SEP]")
)


def _partial_sentence_char(base_tokens, filled, target_pos):
    """Render the char-mode sentence with `filled` predictions applied.

    `base_tokens` excludes the leading [CLS], so a position in the encoded
    sequence maps to index ``position - 1``. The token at `target_pos` is
    wrapped in ``[[R]]...[[/R]]`` unless it is a special token.
    """
    tokens = list(base_tokens)
    for position, token_id in filled.items():
        index = position - 1
        if 0 <= index < len(tokens):
            tokens[index] = id_to_char.get(token_id, "")
    target_index = target_pos - 1
    return "".join(
        f"[[R]]{tok}[[/R]]"
        if i == target_index and tok not in _CHAR_SPECIAL_TOKENS
        else tok
        for i, tok in enumerate(tokens)
    )


def _partial_sentence_bpe(base_tokens, filled, target_pos):
    """Render the BPE-mode sentence with `filled` predictions applied.

    Remaining masks are shown as "[MASK]", the token at `target_pos` is
    wrapped in ``[[R]]...[[/R]]``, CLS/SEP/PAD tokens are dropped, and runs
    of spaces are collapsed.
    """
    tokens = list(base_tokens)
    for position, token_id in filled.items():
        if 0 <= position < len(tokens):
            tokens[position] = tokenizer_bpe.decode(
                [token_id], clean_up_tokenization_spaces=False)

    skipped = (tokenizer_bpe.cls_token, tokenizer_bpe.sep_token,
               tokenizer_bpe.pad_token)
    parts = []
    for i, tok in enumerate(tokens):
        clean = tok.replace("Ġ", " ").replace("##", "")
        if tok == tokenizer_bpe.mask_token:
            parts.append("[MASK]")
        elif i == target_pos:
            parts.append(f"[[R]]{clean}[[/R]]")
        elif tok in skipped:
            continue
        else:
            parts.append(clean)
    return re.sub(r" +", " ", "".join(parts)).strip()


def _build_variant(escaped_text, escaped_mask, beam, mask_positions, is_char):
    """Turn one finished beam into the variant dict returned to the client."""
    ordered_ids = [beam["inserted_tokens"][position] for position in mask_positions]
    if is_char:
        chars = [id_to_char.get(tid, "") for tid in ordered_ids]
        inserted = "".join(chars).strip()
        shown = [" " if ch == " " else html.escape(ch) for ch in chars]
    else:
        inserted = tokenizer_bpe.decode(
            ordered_ids, clean_up_tokenization_spaces=True).strip()
        shown = [
            html.escape(tokenizer_bpe.decode([tid])
                        .replace("Ġ", "").replace("##", "").replace(" ", ""))
            for tid in ordered_ids
        ]

    # Split once on the masks and interleave the predictions. Unlike repeated
    # str.replace(..., 1), this cannot match a mask string that appears inside
    # an already inserted prediction. Masks beyond the predicted ones (for
    # example after truncation) stay in the last piece.
    pieces = escaped_text.split(escaped_mask, len(shown))
    rendered = [pieces[0]]
    for tok, tail in zip(shown, pieces[1:]):
        rendered.append(f'<span class="highlight-restored">{tok}</span>')
        rendered.append(tail)
    full_sentence = re.sub(r"\s+", " ", "".join(rendered).strip())

    return {
        "word": inserted or "...",
        "score": round(math.exp(beam["log_prob"]) * 100, 2),
        "full_sentence": full_sentence,
        "raw_log_prob": beam["log_prob"],
    }
# Request body accepted by POST /api/restore.
class RestoreRequest(BaseModel):
    # Text to restore; the handler turns "#" into "[GAP]" and every "-" into
    # one mask token.
    text: str
    # "char" selects the character-level model; any other value selects the
    # BPE model.
    mode: str = "char"
    # Forwarded to generate_sequential as `top_k`.
    top_k: int = 5
    # Forwarded to generate_sequential as `temperature`.
    temperature: float = 1.0
@app.get("/")
async def read_root(request: Request):
    """Render the ``index.html`` template as the landing page."""
    page = templates.TemplateResponse(request=request, name="index.html")
    return page
@app.post("/api/restore")
async def restore_text(req: RestoreRequest) -> Dict[str, Any]:
    """Classify the submitted text and restore its masked positions.

    Pre-processing: "#" becomes "[GAP]", everything except special tokens is
    lower-cased, and every "-" stands for one position to restore.

    Returns:
        On success: ``status`` "success", ``results`` (one list of variants),
        ``steps`` (one list of restoration steps), ``n_gaps`` (number of "-"
        plus "[GAP]" occurrences) and ``classification``.
        On any exception: ``status`` "error" with ``message`` and the
        formatted ``traceback``.
    """
    try:
        char_mode = req.mode == "char"
        bpe_mask = tokenizer_bpe.mask_token
        fill_token = "[MASK]" if char_mode else bpe_mask

        with_gaps = req.text.replace("#", "[GAP]")
        # Lower-case everything except the special tokens themselves.
        normalized = []
        for piece in SPECIAL_RE.split(with_gaps):
            if not piece:
                continue
            if SPECIAL_RE.fullmatch(piece):
                normalized.append(piece)
            else:
                normalized.append(piece.lower())
        text = "".join(normalized)

        n_gaps = text.count("-") + text.count("[GAP]")

        # The classifier runs on the BPE model, so it always gets the BPE
        # mask token regardless of the requested restoration mode.
        classify_input = re.sub(r"-", bpe_mask, text)
        classify_input = re.sub(r" +", " ", classify_input).strip()
        classification = classify(classify_input)

        # Restoration with the mask token of the selected model.
        query = re.sub(r" +", " ", text.replace("-", fill_token)).strip()
        outcome = generate_sequential(query, char_mode, req.top_k, req.temperature)

        return {
            "status": "success",
            "results": [outcome["variants"]],
            "steps": [outcome["steps"]],
            "n_gaps": n_gaps,
            "classification": classification,
        }
    except Exception as e:
        import traceback
        # NOTE(review): the full traceback is sent to the client — confirm
        # this is acceptable outside of a demo deployment.
        return {"status": "error", "message": str(e),
                "traceback": traceback.format_exc()}
if __name__ == "__main__":
    import os

    # Listen on all interfaces; the port comes from the PORT environment
    # variable and falls back to 7860 when it is unset.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
|