Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification | |
| import torch | |
| import re | |
| import ast | |
app = FastAPI()

# --- Model IDs ---
# Hugging Face Hub checkpoints for lemmatization, morphosyntactic (MSD)
# tagging and named-entity recognition.
LEMMA_MODEL_ID = "ICEF-NLP/bcms-bertic-comtext-sr-legal-lemma-ekavica"
MSD_MODEL_ID = "ICEF-NLP/bcms-bertic-comtext-sr-legal-msd-ekavica"
NER_MODEL_ID = "ICEF-NLP/bcms-bertic-comtext-sr-legal-ner-ekavica"


def _load_token_classifier(model_id):
    """Load a token-classification checkpoint in eval mode with its id -> label map."""
    classifier = AutoModelForTokenClassification.from_pretrained(model_id)
    classifier.eval()
    return classifier, classifier.config.id2label


# --- Load models once at startup ---
# Only the lemma checkpoint's tokenizer is loaded and it is used for all three
# models. NOTE(review): assumes the three checkpoints share one vocabulary.
print("Loading tokenizer...")
hf_tokenizer = AutoTokenizer.from_pretrained(LEMMA_MODEL_ID)

print("Loading lemma model...")
lemma_model, lemma_id2label = _load_token_classifier(LEMMA_MODEL_ID)

print("Loading MSD model...")
msd_model, msd_id2label = _load_token_classifier(MSD_MODEL_ID)

print("Loading NER model...")
ner_model, ner_id2label = _load_token_classifier(NER_MODEL_ID)

print("All models loaded.")
# --- Legal abbreviations ---
# Abbreviations whose trailing dot is never a sentence boundary. They are
# matched case-insensitively and only as whole words (see protect_text).
LEGAL_ABBREVS = {
    'čl', 'st', 'br', 'sl', 'dr', 'mr', 'prof', 'ing', 'tač',
    'str', 'god', 'par', 'al', 'sek', 'tzv', 'itd', 'idr',
    'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
    'jul', 'avg', 'sep', 'okt', 'nov', 'dec',
    'stav', 'red', 'nn', 'rs', 'rh', 'bih',
    'ul', 'bb', 'sp', 'tr', 'tel', 'fax', 'www',
}

# Company suffixes that should never be split — all variants.
# NOTE: these are applied with re.IGNORECASE, so the upper-case entries are
# redundant, and 'doo'/'ad' contain no dot, so masking them changes nothing.
COMPANY_SUFFIXES = [
    r'd\.o\.o\.?', r'D\.O\.O\.?',
    r'a\.d\.?', r'A\.D\.?',
    r'd\.d\.?', r'D\.D\.?',
    r'j\.p\.?', r'J\.P\.?',
    r'k\.d\.?', r'K\.D\.?',
    r'o\.d\.?', r'O\.D\.?',
    r'doo', r'DOO',
    r'ad', r'AD',
]

# ASCII plus Serbian Latin letters (lower- and upper-case), as a regex class body.
_LETTERS = 'a-zA-ZšđčćžŠĐČĆŽ'
# Short letter run, a dot, and a letter (d.o, a.d, j.p, ...). The run must start
# the text or follow whitespace, '(', '„', '"' or an already-masked dot.
_CHAINED_ABBREV_RE = re.compile(
    r'(?<![^\s(„"\x00])([' + _LETTERS + r']{1,3})\.([' + _LETTERS + r'])'
)


def _mask_dots(match):
    """Return the matched text with every '.' replaced by the \\x00 placeholder."""
    return match.group(0).replace('.', '\x00')


def protect_text(text):
    r"""Mask characters that must not act as token or sentence boundaries.

    Dots (and, where relevant, commas, '@', '/' and ':') inside URLs, e-mail
    addresses, dates, numbers, company suffixes, initials and known legal
    abbreviations are replaced by control-character placeholders, so that
    later splitting on '.', ',' or ':' leaves them alone. ``restore_text``
    maps the placeholders back:

        \x00 -> '.'    \x01 -> ','    \x02 -> '@'    \x03 -> '/'    \x04 -> ':'

    Args:
        text: raw text (a sentence or a line).

    Returns:
        The text with protected characters replaced by placeholders.
    """
    # URLs: mask '.', '/' and ':' so the whole URL survives as one token.
    text = re.sub(
        r'https?://[^\s]+',
        lambda m: m.group(0).replace('.', '\x00').replace('/', '\x03').replace(':', '\x04'),
        text,
    )
    text = re.sub(r'www\.[^\s]+', _mask_dots, text)
    # E-mail addresses.
    text = re.sub(
        r'[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}',
        lambda m: m.group(0).replace('.', '\x00').replace('@', '\x02'),
        text,
    )
    # Dates like 12.03.2024. or 12.03.2024 (the trailing dot is part of the date).
    text = re.sub(r'\b\d{1,2}\.\d{1,2}\.\d{2,4}\.?', _mask_dots, text)
    # Amounts with thousands separators, e.g. 1.500.000,00 / 145.000,00 / 1.500.
    text = re.sub(
        r'\b(\d{1,3})(\.\d{3})+([,\x01]\d+)?\b',
        lambda m: m.group(0).replace('.', '\x00').replace(',', '\x01'),
        text,
    )
    # Any other comma or dot between two digits is a decimal separator. The
    # lookarounds keep the digits out of the match, so chains such as "1,2,3"
    # are fully masked. A consuming (\d),(\d) would skip every second separator.
    text = re.sub(r'(?<=\d),(?=\d)', '\x01', text)
    text = re.sub(r'(?<=\d)\.(?=\d)', '\x00', text)
    # Company suffixes (d.o.o., a.d., ...) — all variants.
    for suffix in COMPANY_SUFFIXES:
        text = re.sub(suffix, _mask_dots, text, flags=re.IGNORECASE)
    # Single upper-case initials such as "A." in "A. B. Ovčina".
    text = re.sub(r'(?<!\w)([A-ZŠĐČĆŽ])\.', lambda m: m.group(1) + '\x00', text)
    # Chained letter.letter patterns like d.o.o. a.d. j.p. A match consumes the
    # letter after the dot, so the next link of the chain needs another pass.
    for _ in range(5):
        text = _CHAINED_ABBREV_RE.sub(lambda m: m.group(1) + '\x00' + m.group(2), text)
    # Known legal abbreviations, longest first. (?<!\w) anchors each one to the
    # start of a word. Without it, ordinary words that merely end in an
    # abbreviation ("korist." ends in "st") would lose their sentence-final dot.
    for abbrev in sorted(LEGAL_ABBREVS, key=len, reverse=True):
        text = re.sub(
            r'(?<!\w)' + re.escape(abbrev) + r'\.',
            _mask_dots,
            text,
            flags=re.IGNORECASE,
        )
    return text
# Placeholder character -> original character, the inverse of protect_text's masking.
_PLACEHOLDER_TABLE = str.maketrans({
    '\x00': '.',
    '\x01': ',',
    '\x02': '@',
    '\x03': '/',
    '\x04': ':',
})


def restore_text(text):
    """Replace every protect_text placeholder with the character it stands for."""
    return text.translate(_PLACEHOLDER_TABLE)
# Serbian low-9 opening quote „ (U+201E): always opens a quotation.
_OPENING_QUOTES = ('\u201e',)
# Straight and curly double quotes. They close an open quotation, and open
# one when none is open (straight "..." or English-style “...”).
_CLOSING_QUOTES = ('"', '\u201c', '\u201d')
_ALL_QUOTES = _OPENING_QUOTES + _CLOSING_QUOTES


def fix_quotes(tokens):
    """Re-attach detached quotation marks to the words they enclose.

    „ is opening, " (or “ ”) is closing:
    - „ is attached to the following token.
    - A closing-style quote is attached to the preceding token when a
      quotation is open.
    - When no quotation is open, a closing-style quote is treated as opening,
      so straight-quoted text ("Zakon") pairs up correctly.
    - A quote with nothing suitable to attach to (the neighbour is missing or
      is itself a bare quote) stays a separate token.

    Args:
        tokens: tokens in which quotation marks are separate items.

    Returns:
        A new list of tokens with quotes merged into neighbouring tokens.
    """
    result = []
    quote_open = False
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        nxt = tokens[i + 1] if i + 1 < len(tokens) else None
        opens = tok in _OPENING_QUOTES or (tok in _CLOSING_QUOTES and not quote_open)
        if opens and nxt is not None and nxt not in _ALL_QUOTES:
            result.append(tok + nxt)
            quote_open = True
            i += 2
            continue
        if tok in _CLOSING_QUOTES and result and result[-1] not in _ALL_QUOTES:
            # Closing quote (or a dangling one at the end): glue to the previous token.
            result[-1] += tok
            quote_open = False
        else:
            result.append(tok)
        i += 1
    return result
# Closing quotes / guillemet that may follow sentence-final punctuation, as in
# the token 'Dosta!"' or 'kraj."'.
_SENTENCE_TRAILERS = '"\'\u201c\u201d\u00bb'
# Characters skipped when checking whether the next token starts upper-case.
_NEXT_TOKEN_LEADERS = '"\'\u201e\x00('


def split_sentences(text):
    """Split raw text into a list of sentence strings.

    Every newline is a hard boundary. Within a line, protect_text first masks
    dots that cannot end a sentence (abbreviations, numbers, dates, URLs, ...).
    The line is then cut after a space-separated token that ends, ignoring
    closing quotes, in one of:
    - '!' or '?';
    - '.' when the next token starts with an upper-case letter, or when there
      is no next token.

    Args:
        text: raw input text.

    Returns:
        Non-empty sentence strings with placeholders restored.
    """
    text = re.sub(r'\r\n?', '\n', text)  # normalise CRLF and lone CR line endings
    text = re.sub(r'[ \t]+', ' ', text)

    final_sentences = []
    for segment in text.split('\n'):  # every newline is a hard sentence boundary
        segment = segment.strip()
        if not segment:
            continue
        tokens = protect_text(segment).split(' ')
        current_sentence = []
        for i, token in enumerate(tokens):
            current_sentence.append(token)
            # Look past closing quotes so 'Dosta!"' or 'kraj."' can end a sentence.
            core = token.rstrip(_SENTENCE_TRAILERS)
            if not core:
                continue
            if core.endswith(('!', '?')):
                boundary = True
            elif core.endswith('.'):
                # A dot ends the sentence only before an upper-case word or at
                # the end of the line. Ordinals ("5. maja") and other dots that
                # are followed by lower-case words or digits stay inside.
                if i + 1 < len(tokens):
                    next_tok = tokens[i + 1].lstrip(_NEXT_TOKEN_LEADERS)
                    boundary = bool(next_tok) and next_tok[0].isupper()
                else:
                    boundary = True
            else:
                boundary = False
            if boundary:
                final_sentences.append(restore_text(' '.join(current_sentence).strip()))
                current_sentence = []
        if current_sentence:
            leftover = restore_text(' '.join(current_sentence).strip())
            if leftover:
                final_sentences.append(leftover)
    return [s for s in final_sentences if s.strip()]
def tokenize_sentence(sentence):
    """Split one sentence into word and punctuation tokens.

    Spans masked by protect_text (URLs, e-mails, dates, numbers, company
    suffixes, abbreviations) stay intact. Other punctuation becomes separate
    tokens, and detached quotation marks are glued back by fix_quotes.

    Args:
        sentence: a single sentence string.

    Returns:
        List of non-empty token strings.
    """
    sentence = protect_text(sentence)
    # Brackets, ';', ':', '!', guillemets and en/em dashes are always separate tokens.
    sentence = re.sub(r"([;:!()\[\]{}\u00ab\u00bb\u2013\u2014])", r' \1 ', sentence)
    # A comma is a token unless it sits between two digits (number commas were
    # already masked as \x01 by protect_text).
    sentence = re.sub(r'(?<!\d),|,(?!\d)', ' , ', sentence)
    # Every dot still present is punctuation.
    sentence = sentence.replace('.', ' . ')
    # Detach quotes (keeping their original character) so fix_quotes can
    # re-attach them to the right neighbour.
    sentence = sentence.replace('\u201e', ' \u201e ')
    sentence = re.sub('(["\u201c\u201d])', r' \1 ', sentence)
    # '?' is a token when it ends a word. A '?' followed by more characters
    # (e.g. a URL query string) is left in place.
    sentence = re.sub(r'\?(?!\S)', ' ?', sentence)
    sentence = re.sub(r'\s+', ' ', sentence).strip()
    tokens = [t for t in restore_text(sentence).split(' ') if t.strip()]
    return fix_quotes(tokens)
def _predict_chunk(model, id2label, words):
    """Run ``model`` on one window of pre-split words and return one label per word.

    The shared ``hf_tokenizer`` splits the words into sub-word pieces and
    truncates the encoding to 512 positions. Each word takes the label
    predicted for its first piece. Words with no surviving piece (lost to
    truncation) get "O".
    """
    encoded = hf_tokenizer(
        words,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=False,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    # Batch of one sequence: take row 0 of the per-position argmax.
    predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()

    first_piece_label = {}
    for position, word_index in enumerate(encoded.word_ids()):
        # Skip special tokens (no word) and non-initial pieces of a word.
        if word_index is None or word_index in first_piece_label:
            continue
        first_piece_label[word_index] = id2label[predicted_ids[position]]
    return [first_piece_label.get(n, "O") for n in range(len(words))]
def predict_labels(model, id2label, words, max_words=300, overlap=30):
    """Predict one label per word, processing long inputs in overlapping windows.

    Inputs of at most ``max_words`` words go through _predict_chunk in one
    call. Longer inputs are cut into windows of ``max_words`` words whose
    starts are ``max_words - overlap`` apart. For a word covered by two
    windows, the label from the earlier window is kept.

    Args:
        model: token-classification model passed through to _predict_chunk.
        id2label: mapping from predicted class id to label string.
        words: list of pre-split word strings.
        max_words: maximum words per window (default 300).
        overlap: words shared by consecutive windows (default 30).

    Returns:
        List of labels, one per word ("O" for any word left unlabelled).

    Raises:
        ValueError: if ``max_words < 1`` or ``overlap`` is outside
            ``[0, max_words)``, since the window would then never advance.
    """
    if max_words < 1 or not 0 <= overlap < max_words:
        raise ValueError(
            f"need max_words >= 1 and 0 <= overlap < max_words, got {max_words=}, {overlap=}"
            if False else
            "need max_words >= 1 and 0 <= overlap < max_words, got "
            f"max_words={max_words}, overlap={overlap}"
        )
    if len(words) <= max_words:
        return _predict_chunk(model, id2label, words)

    labels = [None] * len(words)
    step = max_words - overlap
    for start in range(0, len(words), step):
        window = words[start:start + max_words]
        for offset, label in enumerate(_predict_chunk(model, id2label, window)):
            # NOTE(review): overlap words keep the earlier window's label, i.e.
            # the one computed with the least right-hand context; a midpoint
            # split of the overlap may give better labels.
            if labels[start + offset] is None:
                labels[start + offset] = label
        if start + max_words >= len(words):
            break
    return [label if label is not None else "O" for label in labels]
def apply_lemma_transform(token, label):
    """Derive a lemma from ``token`` using an edit-script label.

    ``label`` is the string form of a 4-tuple
    ``(del_end, app_end, del_start, app_start)``. The edits are applied in
    this order:
    1. drop ``del_end`` characters from the end (when positive);
    2. append ``app_end``;
    3. drop ``del_start`` characters from the start (when positive);
    4. prepend ``app_start``.

    ``ast.literal_eval`` only accepts Python literals, so parsing the label is
    safe. If parsing or applying the label fails for any reason, the token is
    returned unchanged.
    """
    try:
        del_end, app_end, del_start, app_start = ast.literal_eval(label)
        stem = token[:-del_end] if del_end > 0 else token
        stem += app_end
        if del_start > 0:
            stem = stem[del_start:]
        return app_start + stem
    except Exception:
        return token
def build_conllu(doc_id, sentences_tokens, sentences_lemmas, sentences_msds, sentences_ners):
    """Serialise annotated sentences as a CoNLL-U-style, tab-separated document.

    The output starts with a ``# newdoc id`` line. Each sentence then
    contributes:
    - ``# sent_id = <doc_id>.<n>`` (1-based);
    - ``# text = `` followed by the space-joined tokens;
    - one row per token with five columns: index, form, lemma, MSD tag and
      NER label (not the 10-column CoNLL-U layout);
    - a blank line.

    The four per-sentence sequences are zipped, so extra items in a longer
    one are ignored. The result always ends with a newline.
    """
    out = [f"# newdoc id = {doc_id}"]
    annotated = zip(sentences_tokens, sentences_lemmas, sentences_msds, sentences_ners)
    for sent_no, (tokens, lemmas, msds, ners) in enumerate(annotated, start=1):
        out.append(f"# sent_id = {doc_id}.{sent_no}")
        out.append(f"# text = {' '.join(tokens)}")
        out.extend(
            f"{tok_no}\t{token}\t{lemma}\t{msd}\t{ner}"
            for tok_no, (token, lemma, msd, ner) in enumerate(
                zip(tokens, lemmas, msds, ners), start=1
            )
        )
        out.append("")
    return "\n".join(out) + "\n"
# --- API ---
# Request body accepted by process(): the text to annotate plus a document id.
class ProcessRequest(BaseModel):
    # Raw input text; split_sentences treats every newline in it as a hard
    # sentence boundary.
    text: str
    # Identifier written into the "# newdoc id" and "# sent_id" header lines.
    doc_id: str = "doc_1"
# Response body returned by process().
class ProcessResponse(BaseModel):
    # Output of build_conllu: tab-separated rows (index, form, lemma, MSD,
    # NER) grouped by sentence, with "#" comment header lines.
    conllu: str
@app.get("/health")
def health():
    """Liveness check: report that the service is up.

    Models are loaded at import time, so a process that can answer is ready.
    The original had no route decorator, so FastAPI never exposed this
    function. NOTE(review): confirm "/health" is the path clients expect.
    """
    return {"status": "ok"}
# Sentences with more tokens than this are skipped rather than annotated.
MAX_SENTENCE_TOKENS = 1000


def _annotate_sentence(idx, sentence):
    """Tokenize and tag one sentence.

    Returns ``(tokens, lemmas, msd_labels, ner_labels)``, or ``None`` when the
    sentence is skipped. A sentence is skipped when:
    - it yields no tokens;
    - it exceeds MAX_SENTENCE_TOKENS tokens;
    - a model raises (the error is printed).
    """
    tokens = tokenize_sentence(sentence)
    if not tokens:
        return None
    if len(tokens) > MAX_SENTENCE_TOKENS:
        print(f"Skipping sentence {idx} — too long ({len(tokens)} tokens)")
        return None
    try:
        lemma_labels = predict_labels(lemma_model, lemma_id2label, tokens)
        msd_labels = predict_labels(msd_model, msd_id2label, tokens)
        ner_labels = predict_labels(ner_model, ner_id2label, tokens)
    except Exception as e:  # best effort: one bad sentence must not fail the request
        print(f"ERROR on sentence {idx}: {sentence[:80]}")
        print(f"Tokens ({len(tokens)}): {tokens[:20]}")
        print(f"Error: {e}")
        return None
    lemmas = [apply_lemma_transform(tok, lbl) for tok, lbl in zip(tokens, lemma_labels)]
    return tokens, lemmas, msd_labels, ner_labels


@app.post("/process", response_model=ProcessResponse)
def process(request: ProcessRequest):
    """Annotate ``request.text`` with lemmas, MSD tags and NER labels.

    The original had no route decorator, so FastAPI never exposed this
    function. NOTE(review): confirm "/process" is the path clients expect.

    Returns:
        ProcessResponse whose ``conllu`` field holds the build_conllu output.

    Raises:
        HTTPException(400): the input text contains no sentences.
        HTTPException(500): no sentence could be annotated, or any other
            unexpected error. The traceback is printed.
    """
    try:
        sentences = split_sentences(request.text)
        if not sentences:
            raise HTTPException(status_code=400, detail="Input text contains no sentences")
        all_tokens, all_lemmas, all_msds, all_ners = [], [], [], []
        for idx, sentence in enumerate(sentences):
            annotated = _annotate_sentence(idx, sentence)
            if annotated is None:
                continue
            tokens, lemmas, msd_labels, ner_labels = annotated
            all_tokens.append(tokens)
            all_lemmas.append(lemmas)
            all_msds.append(msd_labels)
            all_ners.append(ner_labels)
        if not all_tokens:
            raise ValueError("No sentences could be processed")
        conllu = build_conllu(request.doc_id, all_tokens, all_lemmas, all_msds, all_ners)
        return ProcessResponse(conllu=conllu)
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) pass through unchanged.
        raise
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e