from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
import re
import ast

app = FastAPI()

# --- Model IDs ---
# Hugging Face identifiers of the three token-classification checkpoints:
# lemmatisation, morphosyntactic description (MSD) and named entities (NER).
LEMMA_MODEL_ID = "ICEF-NLP/bcms-bertic-comtext-sr-legal-lemma-ekavica"
MSD_MODEL_ID = "ICEF-NLP/bcms-bertic-comtext-sr-legal-msd-ekavica"
NER_MODEL_ID = "ICEF-NLP/bcms-bertic-comtext-sr-legal-ner-ekavica"

# --- Load models once at startup ---
# Everything below runs at import time, so importing this module downloads /
# loads all three models before the app can serve a request.
print("Loading tokenizer...")
# Only one tokenizer is loaded, from the lemma checkpoint.
# NOTE(review): presumably all three models share this vocabulary — confirm.
hf_tokenizer = AutoTokenizer.from_pretrained(LEMMA_MODEL_ID)

print("Loading lemma model...")
lemma_model = AutoModelForTokenClassification.from_pretrained(LEMMA_MODEL_ID)
lemma_model.eval()  # evaluation mode: the models are only used for inference
lemma_id2label = lemma_model.config.id2label  # class index -> label string

print("Loading MSD model...")
msd_model = AutoModelForTokenClassification.from_pretrained(MSD_MODEL_ID)
msd_model.eval()
msd_id2label = msd_model.config.id2label

print("Loading NER model...")
ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_ID)
ner_model.eval()
ner_id2label = ner_model.config.id2label

print("All models loaded.")

# --- Legal abbreviations ---
# Lowercase abbreviations (legal terms, titles, month names, country codes,
# address/contact markers). Not referenced in the part of the file visible
# here. NOTE(review): presumably consulted by the sentence splitter so that a
# dot after one of these is not treated as a sentence end — confirm.
LEGAL_ABBREVS = {
    'čl', 'st', 'br', 'sl', 'dr', 'mr', 'prof', 'ing', 'tač', 'str',
    'god', 'par', 'al', 'sek', 'tzv', 'itd', 'idr',
    'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
    'jul', 'avg', 'sep', 'okt', 'nov', 'dec',
    'stav', 'red', 'nn', 'rs', 'rh', 'bih', 'ul', 'bb', 'sp', 'tr',
    'tel', 'fax', 'www'
}

# Company suffixes that should never be split — all variants.
# These are regex fragments, applied in protect_text with re.IGNORECASE, so
# the upper/lower-case pairs match the same text. The dot-less entries
# ('doo', 'ad') contain nothing for protect_text to replace, so they are
# no-ops there.
COMPANY_SUFFIXES = [
    r'd\.o\.o\.?', r'D\.O\.O\.?',
    r'a\.d\.?', r'A\.D\.?',
    r'd\.d\.?', r'D\.D\.?',
    r'j\.p\.?', r'J\.P\.?',
    r'k\.d\.?', r'K\.D\.?',
    r'o\.d\.?', r'O\.D\.?',
    r'doo', r'DOO',
    r'ad', r'AD',
]


def protect_text(text):
    """Replace all dots/commas that are NOT boundaries with placeholders.

    Punctuation inside URLs, e-mail addresses, dates, numbers and company
    suffixes is swapped for control characters so that later processing does
    not mistake it for a boundary. The substitutions run in a fixed order and
    later patterns rely on what earlier ones already replaced.
    """
    # Placeholder characters used below:
    #   \x00 = '.'   \x01 = ','   \x02 = '@'   \x03 = '/'   \x04 = ':'

    # Protect URLs (dots, slashes and colons of anything starting http(s)://)
    text = re.sub(
        r'https?://[^\s]+',
        lambda m: m.group(0).replace('.', '\x00').replace('/', '\x03').replace(':', '\x04'),
        text
    )
    # Scheme-less URLs: only the dots are protected here.
    text = re.sub(
        r'www\.[^\s]+',
        lambda m: m.group(0).replace('.', '\x00'),
        text
    )
    # Protect emails
    text = re.sub(
        r'[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}',
        lambda m: m.group(0).replace('.', '\x00').replace('@', '\x02'),
        text
    )
    # Protect dates like 12.03.2024. or 12.03.2024
    # The optional trailing dot is part of the match, so it is protected too,
    # even when the date is the last thing in a sentence.
    text = re.sub(
        r'\b\d{1,2}\.\d{1,2}\.\d{2,4}\.?',
        lambda m: m.group(0).replace('.', '\x00'),
        text
    )
    # Protect money/amount decimals and thousands separators
    # e.g. 1.500.000,00 or 145.000,00 or 1.500
    # dot between digits only when it looks like thousands separator (groups of 3)
    # The decimal part accepts either a literal comma or an already-placed
    # \x01 placeholder.
    text = re.sub(
        r'\b(\d{1,3})(\.\d{3})+([,\x01]\d+)?\b',
        lambda m: m.group(0).replace('.', '\x00').replace(',', '\x01'),
        text
    )
    # comma between digits = decimal separator in amounts
    text = re.sub(r'(\d),(\d)', lambda m: m.group(1) + '\x01' + m.group(2), text)
    # remaining dot between digits (e.g. simple decimals)
    text = re.sub(r'(\d)\.(\d)', lambda m: m.group(1) + '\x00' + m.group(2), text)
    # Protect company suffixes — all variants
    # Patterns carry no word-boundary anchors, so they match anywhere in the
    # text; only dots inside a match are replaced.
    for suffix in COMPANY_SUFFIXES:
        text = re.sub(
            suffix,
            lambda m: m.group(0).replace('.', '\x00'),
            text,
            flags=re.IGNORECASE
        )
    # Protect single uppercase initials like A. B. Ov.
    # NOTE(review): the source is cut off inside the pattern below; the rest of
    # this function is not visible here and is left exactly as found.
    text = re.sub( r'(? 
0: result = result[:-del_end] result = result + app_end if del_start > 0: result = result[del_start:] result = app_start + result return result except Exception: return token def build_conllu(doc_id, sentences_tokens, sentences_lemmas, sentences_msds, sentences_ners): lines = [f"# newdoc id = {doc_id}"] for sent_idx, (tokens, lemmas, msds, ners) in enumerate( zip(sentences_tokens, sentences_lemmas, sentences_msds, sentences_ners) ): lines.append(f"# sent_id = {doc_id}.{sent_idx + 1}") lines.append(f"# text = {' '.join(tokens)}") for i, (token, lemma, msd, ner) in enumerate(zip(tokens, lemmas, msds, ners)): lines.append(f"{i + 1}\t{token}\t{lemma}\t{msd}\t{ner}") lines.append("") return "\n".join(lines) + "\n" # --- API --- class ProcessRequest(BaseModel): text: str doc_id: str = "doc_1" class ProcessResponse(BaseModel): conllu: str @app.get("/health") def health(): return {"status": "ok"} @app.post("/process", response_model=ProcessResponse) def process(request: ProcessRequest): try: sentences = split_sentences(request.text) all_tokens, all_lemmas, all_msds, all_ners = [], [], [], [] for idx, sentence in enumerate(sentences): tokens = tokenize_sentence(sentence) if not tokens: continue if len(tokens) > 1000: print(f"Skipping sentence {idx} — too long ({len(tokens)} tokens)") continue try: lemma_labels = predict_labels(lemma_model, lemma_id2label, tokens) msd_labels = predict_labels(msd_model, msd_id2label, tokens) ner_labels = predict_labels(ner_model, ner_id2label, tokens) except Exception as e: print(f"ERROR on sentence {idx}: {sentence[:80]}") print(f"Tokens ({len(tokens)}): {tokens[:20]}") print(f"Error: {e}") continue lemmas = [apply_lemma_transform(tok, lbl) for tok, lbl in zip(tokens, lemma_labels)] all_tokens.append(tokens) all_lemmas.append(lemmas) all_msds.append(msd_labels) all_ners.append(ner_labels) if not all_tokens: raise ValueError("No sentences could be processed") conllu = build_conllu(request.doc_id, all_tokens, all_lemmas, all_msds, all_ners) 
return ProcessResponse(conllu=conllu) except Exception as e: import traceback print(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e))