# Aleks1706's picture
# Create app.py
# 83f11fd verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
import re
import ast
app = FastAPI()

# --- Model IDs ---
LEMMA_MODEL_ID = "ICEF-NLP/bcms-bertic-comtext-sr-legal-lemma-ekavica"
MSD_MODEL_ID = "ICEF-NLP/bcms-bertic-comtext-sr-legal-msd-ekavica"
NER_MODEL_ID = "ICEF-NLP/bcms-bertic-comtext-sr-legal-ner-ekavica"


def _load_token_classifier(model_id, display_name):
    """Load a token-classification checkpoint and put it in evaluation mode.

    Prints a "Loading <display_name> model..." progress line first, then
    returns ``(model, model.config.id2label)``.
    """
    print(f"Loading {display_name} model...")
    classifier = AutoModelForTokenClassification.from_pretrained(model_id)
    classifier.eval()
    return classifier, classifier.config.id2label


# --- Load models once at startup ---
# A single tokenizer (from the lemma checkpoint) is used for all three models.
# NOTE(review): this assumes the MSD and NER checkpoints share the lemma
# model's vocabulary — confirm.
print("Loading tokenizer...")
hf_tokenizer = AutoTokenizer.from_pretrained(LEMMA_MODEL_ID)
lemma_model, lemma_id2label = _load_token_classifier(LEMMA_MODEL_ID, "lemma")
msd_model, msd_id2label = _load_token_classifier(MSD_MODEL_ID, "MSD")
ner_model, ner_id2label = _load_token_classifier(NER_MODEL_ID, "NER")
print("All models loaded.")
# --- Legal abbreviations ---
# Lower-case abbreviations (without their trailing dot) whose dot
# protect_text masks, so it is not taken as a sentence or token boundary.
# protect_text matches them case-insensitively.
LEGAL_ABBREVS = set(
    (
        "čl st br sl dr mr prof ing tač "
        "str god par al sek tzv itd idr "
        "jan feb mar apr maj jun "
        "jul avg sep okt nov dec "
        "stav red nn rs rh bih "
        "ul bb sp tr tel fax www"
    ).split()
)
# Company suffixes that should never be split — regex patterns, each
# lower-case form immediately followed by its upper-case twin. protect_text
# applies them with re.IGNORECASE, so the twins are redundant there; they are
# kept so the list's contents and order stay the same.
COMPANY_SUFFIXES = [
    variant
    for base in (
        r'd\.o\.o\.?',
        r'a\.d\.?',
        r'd\.d\.?',
        r'j\.p\.?',
        r'k\.d\.?',
        r'o\.d\.?',
        r'doo',
        r'ad',
    )
    for variant in (base, base.upper())
]
def protect_text(text, abbrevs=None, company_suffixes=None):
    r"""Mask punctuation that must not act as a token or sentence boundary.

    Characters inside URLs, e-mail addresses, dates, numbers, company
    suffixes, upper-case initials, short chained letter groups and known
    abbreviations are replaced with control-character placeholders, which
    ``restore_text`` maps back:

        \x00 -> '.'   \x01 -> ','   \x02 -> '@'   \x03 -> '/'   \x04 -> ':'

    Every replacement is one character for one character, so the result has
    the same length as ``text``.

    Args:
        text: Input string.
        abbrevs: Abbreviations (without their trailing dot) whose dot is
            masked when the abbreviation stands as a whole word. Defaults to
            the module-level ``LEGAL_ABBREVS``.
        company_suffixes: Regex patterns for company-form suffixes whose dots
            are masked (matched case-insensitively). Defaults to the
            module-level ``COMPANY_SUFFIXES``.

    Returns:
        The masked string.
    """
    if abbrevs is None:
        abbrevs = LEGAL_ABBREVS
    if company_suffixes is None:
        company_suffixes = COMPANY_SUFFIXES

    def mask_dots(match):
        return match.group(0).replace('.', '\x00')

    # URLs: mask '.', '/' and ':'. The lazy match stops before trailing
    # sentence punctuation, so the final '.' of "vidi https://x.rs." can still
    # end the sentence.
    text = re.sub(
        r'https?://\S+?(?=[.,;:!?]*(?:\s|$))',
        lambda m: m.group(0).replace('.', '\x00').replace('/', '\x03').replace(':', '\x04'),
        text,
    )
    text = re.sub(r'www\.\S+?(?=[.,;:!?]*(?:\s|$))', mask_dots, text)

    # E-mail addresses: mask '.' and '@'.
    text = re.sub(
        r'[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}',
        lambda m: m.group(0).replace('.', '\x00').replace('@', '\x02'),
        text,
    )

    # Dates like 12.03.2024. or 12.03.2024 (the optional trailing dot too).
    text = re.sub(r'\b\d{1,2}\.\d{1,2}\.\d{2,4}\.?', mask_dots, text)

    # Amounts with thousands separators, e.g. 1.500.000,00 / 145.000,00 / 1.500.
    text = re.sub(
        r'\b(\d{1,3})(\.\d{3})+([,\x01]\d+)?\b',
        lambda m: m.group(0).replace('.', '\x00').replace(',', '\x01'),
        text,
    )

    # Any remaining comma or dot between two digits (decimals, numbered
    # sub-points). Lookarounds mask chains such as 1.2.3 completely; a
    # digit-consuming pattern would skip every second separator.
    text = re.sub(r'(?<=\d),(?=\d)', '\x01', text)
    text = re.sub(r'(?<=\d)\.(?=\d)', '\x00', text)

    # Company suffixes (d.o.o., a.d., ...).
    for suffix in company_suffixes:
        text = re.sub(suffix, mask_dots, text, flags=re.IGNORECASE)

    # Single upper-case initials like A. B.
    text = re.sub(r'(?<!\w)([A-ZŠĐČĆŽ])\.', lambda m: m.group(1) + '\x00', text)

    # Chained letter groups like d.o.o. a.d. j.p. Repeated because a match can
    # only start right after a dot that a previous pass has masked.
    # (The lower-case class previously contained the mojibake 'æ' instead of 'ć'.)
    for _ in range(5):
        text = re.sub(
            r'(?<![^\s(„"\x00])([a-zA-ZšđčćžŠĐČĆŽ]{1,3})\.([a-zA-ZšđčćžŠĐČĆŽ])',
            lambda m: m.group(1) + '\x00' + m.group(2),
            text,
        )

    # Abbreviations, longest first. The look-behind requires a whole word:
    # without it 'st' also matched the end of "vlast." and 'al' the end of
    # "materijal.", hiding genuine sentence-final periods. A preceding masked
    # dot is excluded too, so a domain such as "sud.rs." keeps its final
    # period.
    for abbrev in sorted(abbrevs, key=len, reverse=True):
        text = re.sub(
            r'(?<![\w\x00])' + re.escape(abbrev) + r'\.',
            mask_dots,
            text,
            flags=re.IGNORECASE,
        )

    return text
# Placeholder -> original character; the reverse of the masking protect_text
# performs.
_RESTORE_TABLE = str.maketrans({
    '\x00': '.',
    '\x01': ',',
    '\x02': '@',
    '\x03': '/',
    '\x04': ':',
})


def restore_text(text):
    """Replace protect_text's control-character placeholders with the
    characters they stand for ('.', ',', '@', '/', ':')."""
    return text.translate(_RESTORE_TABLE)
def fix_quotes(tokens):
    """Glue detached quotation marks back onto the words they enclose.

    ``„`` always opens a quotation. The straight ``"`` and the typographic
    ``“``/``”`` are ambiguous, so their role is decided by state: they close
    a quotation that is currently open, and otherwise open a new one —
    unless they are the last token, in which case they close.

    An opening quote is prefixed to the following token and a closing quote
    is suffixed to the preceding token. A quote with no suitable neighbour
    (no neighbour at all, or the neighbour is itself a quote) stays a
    separate token.

    Previously every ``"`` was first tried as an opening quote, so a closing
    quote followed by a word was glued onto that word
    (``„ Zakon " je`` -> ``„Zakon``, ``"je``).

    Args:
        tokens: List of token strings, with quotes as separate tokens.

    Returns:
        A new list of tokens with the quotes re-attached.
    """
    quote_marks = ('„', '"', '“', '”')
    merged = []
    inside_quote = False
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok not in quote_marks:
            merged.append(tok)
            i += 1
            continue

        has_next = i + 1 < len(tokens)
        if tok == '„' or (not inside_quote and has_next):
            # Opening quote: prefix it to the following word.
            inside_quote = True
            if has_next and tokens[i + 1] not in quote_marks:
                merged.append(tok + tokens[i + 1])
                i += 2
                continue
        else:
            # Closing quote: suffix it to the preceding word.
            inside_quote = False
            if merged and merged[-1] not in quote_marks:
                merged[-1] += tok
                i += 1
                continue

        # No suitable neighbour: keep the quote as its own token.
        merged.append(tok)
        i += 1
    return merged
def split_sentences(text):
    """Split raw text into a list of sentence strings.

    Every line break is a hard sentence boundary. Inside a line, the text is
    masked with ``protect_text`` (so dots in abbreviations, dates, numbers,
    URLs, ... are invisible here) and split on spaces. A token closes the
    current sentence when, ignoring trailing closing quotes/brackets:

    * it ends with ``!`` or ``?``; or
    * it ends with ``.`` and either is the last token of the line, or the
      next token (after skipping leading opening quotes/brackets) starts with
      an upper-case letter.

    Returns:
        Non-empty sentences with the protect_text placeholders restored.
    """
    # May follow a sentence terminator, e.g. 'važan."' or 'kraj.)'.
    closers = '"\'\u201c\u201d\u00ab\u00bb)]'
    # Skipped before testing whether the next token is capitalised.
    openers = '"\'„\x00(\u201c\u201d\u00ab\u00bb['

    text = text.replace('\r\n', '\n').replace('\r', '\n')
    # Collapse runs of any non-newline whitespace (tabs, NBSP, ...) into one
    # space; otherwise "kraj.<NBSP>Novi" stays one token and the boundary is
    # missed.
    text = re.sub(r'[^\S\n]+', ' ', text)

    sentences = []
    for segment in text.split('\n'):
        segment = segment.strip()
        if not segment:
            continue
        tokens = protect_text(segment).split(' ')
        current = []
        for i, token in enumerate(tokens):
            current.append(token)
            if not token:
                continue
            core = token.rstrip(closers)
            if core.endswith(('!', '?')):
                is_boundary = True
            elif core.endswith('.'):
                if i + 1 < len(tokens):
                    next_tok = tokens[i + 1].lstrip(openers)
                    # A lower-case word or a digit after the dot (e.g. the
                    # ordinal in "5. maja") keeps the sentence going.
                    is_boundary = bool(next_tok) and next_tok[0].isupper()
                else:
                    is_boundary = True
            else:
                is_boundary = False
            if is_boundary:
                sentences.append(restore_text(' '.join(current).strip()))
                current = []
        if current:
            leftover = restore_text(' '.join(current).strip())
            if leftover:
                sentences.append(leftover)
    return [s for s in sentences if s.strip()]
def tokenize_sentence(sentence):
    """Split one sentence into word and punctuation tokens.

    The sentence is first masked with ``protect_text``, so dots and commas
    inside abbreviations, numbers, dates, URLs, e-mails etc. stay inside
    their token. Remaining punctuation is padded with spaces and split off.
    Quotation marks are detached and then re-attached to the adjacent word
    by ``fix_quotes``.

    Returns:
        List of non-empty token strings.
    """
    masked = protect_text(sentence)
    # Brackets, colon/semicolon, '!', guillemets and en/em dashes.
    masked = re.sub(r"([;:!()\[\]{}\u00ab\u00bb\u2013\u2014])", r' \1 ', masked)
    # Every comma except one sitting between two digits (protect_text masks
    # most of those as \x01 already).
    masked = re.sub(r'(?<!\d),|,(?!\d)', ' , ', masked)
    # Any dot that survived masking is a token of its own.
    masked = masked.replace('.', ' . ')
    # Detach quotes so fix_quotes can re-attach them to the right word.
    # NOTE(review): this used the character class ["""] (the same character
    # three times), which suggests typographic quotes such as “ ” were lost
    # in an encoding round-trip; only the ASCII quote is handled — confirm.
    masked = masked.replace('„', ' „ ').replace('"', ' " ')
    # '?' was never split off (unlike '!'), so "zakon?" stayed one token.
    # Split it only where it ends a word, leaving '?' inside URLs (query
    # strings are not masked by protect_text) untouched.
    masked = re.sub(r'\?(?=[\s\u201c\u201d]|$)', ' ? ', masked)
    masked = re.sub(r'\s+', ' ', masked).strip()
    tokens = [tok for tok in restore_text(masked).split(' ') if tok.strip()]
    return fix_quotes(tokens)
def _predict_chunk(model, id2label, words):
    """Label a list of words with a single forward pass of ``model``.

    Words are encoded with the module-level ``hf_tokenizer`` (truncated to
    512 tokens, no padding). Each word gets the label predicted for its
    first sub-token. Words with no sub-token in the encoding (e.g. cut off
    by truncation) get ``"O"``.

    Returns:
        List of label strings, one per word.
    """
    encoded = hf_tokenizer(
        words,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=False,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    label_ids = logits.argmax(dim=-1)[0].tolist()

    first_subtoken_label = {}
    for position, word_index in enumerate(encoded.word_ids()):
        # Special tokens map to None; later sub-tokens of a word are ignored.
        if word_index is None or word_index in first_subtoken_label:
            continue
        first_subtoken_label[word_index] = id2label[label_ids[position]]
    return [first_subtoken_label.get(k, "O") for k in range(len(words))]
def predict_labels(model, id2label, words, *, max_words=300, overlap=30):
    """Predict one label per word, windowing long inputs.

    Inputs of at most ``max_words`` words are labelled in one pass. Longer
    inputs are processed in windows of ``max_words`` words whose starts are
    ``max_words - overlap`` words apart. Where windows overlap, the label
    from the earlier window is kept.

    Args:
        model: Token-classification model, passed to ``_predict_chunk``.
        id2label: Label-id mapping, passed to ``_predict_chunk``.
        words: List of word strings.
        max_words: Window size in words (keyword-only, default 300).
        overlap: Number of words shared by consecutive windows
            (keyword-only, default 30).

    Returns:
        List of labels, one per word.

    Raises:
        ValueError: If ``max_words < 1`` or ``overlap`` is outside
            ``[0, max_words)``; such values would stop the window from
            advancing.
    """
    if max_words < 1 or not 0 <= overlap < max_words:
        raise ValueError(
            "need max_words >= 1 and 0 <= overlap < max_words, "
            "got max_words=%r, overlap=%r" % (max_words, overlap)
        )

    n_words = len(words)
    if n_words <= max_words:
        return _predict_chunk(model, id2label, words)

    # NOTE(review): _predict_chunk truncates at 512 tokens and returns "O" for
    # words cut off by truncation. With earlier-window-wins stitching, such a
    # placeholder inside an overlap is never replaced by the later window's
    # real prediction. Consider sizing windows by sub-token count — confirm.
    labels = [None] * n_words
    stride = max_words - overlap
    for start in range(0, n_words, stride):
        end = min(start + max_words, n_words)
        window_labels = _predict_chunk(model, id2label, words[start:end])
        for offset, label in enumerate(window_labels):
            if labels[start + offset] is None:
                labels[start + offset] = label
        if end == n_words:
            break
    return [label if label is not None else "O" for label in labels]
def apply_lemma_transform(token, label):
    """Derive a lemma from ``token`` using an edit-script ``label``.

    ``label`` is expected to be the literal text of a 4-item sequence
    ``(del_end, app_end, del_start, app_start)``. The edits are applied in
    this order:

    1. drop ``del_end`` characters from the end;
    2. append ``app_end``;
    3. drop ``del_start`` characters from the start;
    4. prepend ``app_start``.

    Any label that cannot be parsed or applied this way (e.g. the ``"O"``
    fallback) leaves the token unchanged.
    """
    try:
        cut_tail, tail, cut_head, head = ast.literal_eval(label)
        stem = token[:-cut_tail] if cut_tail > 0 else token
        stem += tail
        if cut_head > 0:
            stem = stem[cut_head:]
        return head + stem
    except Exception:
        # Best effort: an unknown or malformed label falls back to the surface form.
        return token
def build_conllu(doc_id, sentences_tokens, sentences_lemmas, sentences_msds, sentences_ners):
    """Serialise annotated sentences as CoNLL-U-style text.

    Output layout:

    * a ``# newdoc id = <doc_id>`` header;
    * for each sentence, a ``# sent_id = <doc_id>.<n>`` comment (1-based ``n``)
      and a ``# text = ...`` comment (the tokens joined by single spaces);
    * one line per token with five tab-separated columns: 1-based index,
      token, lemma, MSD, NER;
    * a blank separator line after each sentence.

    The per-sentence sequences are combined with ``zip``, so output stops at
    the shortest one. The result always ends with a newline.
    """
    rows = [f"# newdoc id = {doc_id}"]
    annotated = zip(sentences_tokens, sentences_lemmas, sentences_msds, sentences_ners)
    for sent_no, (tokens, lemmas, msds, ners) in enumerate(annotated, start=1):
        rows.append(f"# sent_id = {doc_id}.{sent_no}")
        rows.append(f"# text = {' '.join(tokens)}")
        rows.extend(
            f"{tok_no}\t{token}\t{lemma}\t{msd}\t{ner}"
            for tok_no, (token, lemma, msd, ner) in enumerate(zip(tokens, lemmas, msds, ners), start=1)
        )
        rows.append("")
    return "\n".join(rows) + "\n"
# --- API ---
# Request body for POST /process.
class ProcessRequest(BaseModel):
    # Raw text to annotate. process() splits it with split_sentences, which
    # treats every line break as a sentence boundary.
    text: str
    # Written to the "# newdoc id" line and used as the prefix of every
    # "# sent_id" in the output.
    doc_id: str = "doc_1"


# Response body for POST /process.
class ProcessResponse(BaseModel):
    # The annotated document as rendered by build_conllu.
    conllu: str
@app.get("/health")
def health():
    # Liveness check: returns a constant payload and does not touch the models.
    return dict(status="ok")
@app.post("/process", response_model=ProcessResponse)
def process(request: ProcessRequest):
    # Processing steps:
    #   1. Split request.text into sentences.
    #   2. Tokenize each sentence.
    #   3. Run the lemma, MSD and NER models.
    #   4. Return the result rendered by build_conllu.
    #
    # Best effort per sentence: a sentence with more than 1000 tokens, or one
    # whose model inference raises, is reported on stdout and skipped. The
    # remaining sentences are still returned.
    #
    # Errors:
    #   422 - request.text is empty or whitespace-only (nothing to annotate).
    #   500 - any other failure, including every sentence being skipped.
    if not request.text.strip():
        # A client error, not a server fault. It is raised before the
        # catch-all below, which previously turned it into a 500
        # ("No sentences could be processed").
        raise HTTPException(status_code=422, detail="Text is empty")
    try:
        sentences = split_sentences(request.text)
        all_tokens, all_lemmas, all_msds, all_ners = [], [], [], []
        for idx, sentence in enumerate(sentences):
            tokens = tokenize_sentence(sentence)
            if not tokens:
                continue
            if len(tokens) > 1000:
                print(f"Skipping sentence {idx} — too long ({len(tokens)} tokens)")
                continue
            try:
                lemma_labels = predict_labels(lemma_model, lemma_id2label, tokens)
                msd_labels = predict_labels(msd_model, msd_id2label, tokens)
                ner_labels = predict_labels(ner_model, ner_id2label, tokens)
            except Exception as e:
                # Deliberate best effort: log the failure and drop only this sentence.
                print(f"ERROR on sentence {idx}: {sentence[:80]}")
                print(f"Tokens ({len(tokens)}): {tokens[:20]}")
                print(f"Error: {e}")
                continue
            lemmas = [apply_lemma_transform(tok, lbl) for tok, lbl in zip(tokens, lemma_labels)]
            all_tokens.append(tokens)
            all_lemmas.append(lemmas)
            all_msds.append(msd_labels)
            all_ners.append(ner_labels)
        if not all_tokens:
            raise ValueError("No sentences could be processed")
        conllu = build_conllu(request.doc_id, all_tokens, all_lemmas, all_msds, all_ners)
        return ProcessResponse(conllu=conllu)
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e