| """ |
| PrivaMesh Legal โ Fine-tuning script |
| Fine-tune Mistral-7B for legal PII token classification (BIOES scheme) |
| |
| Usage: |
| python train.py --config configs/legal_fr.yaml |
| |
| Requirements: |
| pip install transformers datasets peft accelerate bitsandbytes seqeval torch |
| """ |
|
|
import os
import json
import random
import argparse
import logging
from dataclasses import dataclass
from typing import Optional, List, Dict, Any

import torch
import numpy as np
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
)
from peft import (
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)
import evaluate

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# =============================================================================
# Label scheme (BIOES)
# =============================================================================

PRIVACY_CATEGORIES = [
    # People
    "PERSON_NAME",
    "LEGAL_COUNSEL",
    "JUDGE_NAME",
    "SIGNATORY",
    "WITNESS",
    # Organizations
    "COMPANY_NAME",
    "COMPANY_ID",
    "COURT_NAME",
    "BAR_ASSOCIATION",
    # Financial
    "CONTRACT_AMOUNT",
    "BANK_ACCOUNT",
    "PENALTY_AMOUNT",
    # Contact details
    "PRIVATE_ADDRESS",
    "PRIVATE_EMAIL",
    "PRIVATE_PHONE",
    # Dates and references
    "CONTRACT_DATE",
    "DEADLINE",
    "CASE_NUMBER",
    # GDPR / compliance
    "DATA_SUBJECT",
    "DPO_IDENTITY",
    "PROCESSING_PURPOSE",
    "AUDIT_REFERENCE",
    "REGULATORY_BODY",
    "DIRIGEANT",
]

# One "O" tag plus B/I/E/S variants of every category.
LABELS = ["O"]
for cat in PRIVACY_CATEGORIES:
    for prefix in ["B", "I", "E", "S"]:
        LABELS.append(f"{prefix}-{cat}")

LABEL2ID = {label: idx for idx, label in enumerate(LABELS)}
ID2LABEL = {idx: label for label, idx in LABEL2ID.items()}
NUM_LABELS = len(LABELS)
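# Cheap import-time sanity check: 24 categories x 4 prefixes + "O" = 97 labels.
# Illustration of the scheme: a three-word LEGAL_COUNSEL entity such as
# "Maître Jean Dupont" is tagged ["B-LEGAL_COUNSEL", "I-LEGAL_COUNSEL",
# "E-LEGAL_COUNSEL"], while a single-word entity gets "S-LEGAL_COUNSEL".
assert NUM_LABELS == len(PRIVACY_CATEGORIES) * 4 + 1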
# =============================================================================
# Placeholder templates
# =============================================================================

PLACEHOLDER_TEMPLATES = {
    "PERSON_NAME": "[PERSONNE_{n}]",
    "LEGAL_COUNSEL": "[AVOCAT_{n}]",
    "JUDGE_NAME": "[MAGISTRAT_{n}]",
    "SIGNATORY": "[SIGNATAIRE_{n}]",
    "WITNESS": "[TEMOIN_{n}]",
    "COMPANY_NAME": "[SOCIETE_{n}]",
    "COMPANY_ID": "[SIRET_{n}]",
    "COURT_NAME": "[JURIDICTION_{n}]",
    "BAR_ASSOCIATION": "[BARREAU_{n}]",
    "CONTRACT_AMOUNT": "[MONTANT_{n}]",
    "BANK_ACCOUNT": "[IBAN_{n}]",
    "PENALTY_AMOUNT": "[PENALITE_{n}]",
    "PRIVATE_ADDRESS": "[ADRESSE_{n}]",
    "PRIVATE_EMAIL": "[EMAIL_{n}]",
    "PRIVATE_PHONE": "[TEL_{n}]",
    "CONTRACT_DATE": "[DATE_{n}]",
    "DEADLINE": "[ECHEANCE_{n}]",
    "CASE_NUMBER": "[DOSSIER_{n}]",
    "DATA_SUBJECT": "[PERSONNE_CONCERNEE_{n}]",
    "DPO_IDENTITY": "[DPO_{n}]",
    "PROCESSING_PURPOSE": "[FINALITE_{n}]",
    "AUDIT_REFERENCE": "[AUDIT_REF_{n}]",
    "REGULATORY_BODY": "[AUTORITE_{n}]",
    "DIRIGEANT": "[DIRIGEANT_{n}]",
}
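# For example, PLACEHOLDER_TEMPLATES["PERSON_NAME"].format(n=1) yields
# "[PERSONNE_1]"; the counter n increments per distinct entity of a label,
# so a second person becomes "[PERSONNE_2]" (see _get_placeholder below).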
# =============================================================================
# Synthetic data generation
# =============================================================================

class LegalDataGenerator:
    """
    Generates synthetic annotated legal document examples for training.
    Used to bootstrap training data before fine-tuning on real corpora.
    """

    PERSON_NAMES_FR = [
        "Jean Dupont", "Marie Martin", "Pierre Leblanc", "Sophie Durand",
        "Ahmed Benali", "Fatima Zahra", "Karim Mansouri", "Isabelle Lefebvre",
        "François Moreau", "Nathalie Petit", "Mehdi Rachidi", "Claire Rousseau",
    ]

    COMPANY_NAMES = [
        "Nexum SAS", "TechLegal SA", "DataPro SARL", "InnovateFR SAS",
        "ConsultPro SAS", "LegalTech SA", "SecureData SARL", "CloudFR SAS",
    ]

    LAWYERS = [
        "Maître Jean Dupont", "Maître Sophie Martin", "Maître Ahmed Benali",
        "Maître Claire Rousseau", "Maître François Moreau",
    ]

    AMOUNTS = [
        "150 000 EUR", "25 000 EUR", "500 000 EUR", "75 000 EUR",
        "12 500 EUR", "1 000 000 EUR", "350 000 EUR",
    ]

    SIRETS = [
        "123 456 789 00012", "987 654 321 00045", "456 789 123 00078",
        "321 654 987 00034",
    ]

    ADDRESSES = [
        "12 rue de la Paix, 75001 Paris",
        "45 avenue Victor Hugo, 69002 Lyon",
        "8 place de la République, 33000 Bordeaux",
        "22 boulevard Haussmann, 75009 Paris",
    ]

    EMAILS = [
        "j.dupont@cabinet-dupont.fr",
        "contact@nexumsas.fr",
        "direction@techpro.com",
    ]

    IBANS = [
        "FR76 3000 4000 0100 0000 1234 567",
        "FR76 1670 6000 0302 0060 0800 073",
    ]

    TEMPLATES = [
        {
            "text": "Le contrat conclu entre {lawyer}, avocat au barreau de Paris (SIRET {siret}), et la société {company}, représentée par {person} en qualité de Directeur Général, prévoit une indemnité de rupture de {amount} conformément à l'article L.1237-19 du Code du travail.",
            "entities": [
                ("lawyer", "LEGAL_COUNSEL"),
                ("siret", "COMPANY_ID"),
                ("company", "COMPANY_NAME"),
                ("person", "DIRIGEANT"),
                ("amount", "CONTRACT_AMOUNT"),
            ]
        },
        {
            "text": "M. {person}, domicilié au {address}, a mandaté {lawyer} pour le représenter dans la procédure RG n°24/01234 devant le Tribunal de Commerce de Paris.",
            "entities": [
                ("person", "PERSON_NAME"),
                ("address", "PRIVATE_ADDRESS"),
                ("lawyer", "LEGAL_COUNSEL"),
            ]
        },
        {
            "text": "Conformément au RGPD, la société {company} (SIRET {siret}) désigne {person} en qualité de Délégué à la Protection des Données (DPO), joignable à l'adresse {email}.",
            "entities": [
                ("company", "COMPANY_NAME"),
                ("siret", "COMPANY_ID"),
                ("person", "DPO_IDENTITY"),
                ("email", "PRIVATE_EMAIL"),
            ]
        },
        {
            "text": "La présente convention de prestation de services est conclue entre {company1} (RCS Paris B {siret}) et {company2}, représentée par {person}, pour un montant annuel de {amount} HT, payable par virement sur le compte {iban}.",
            "entities": [
                ("company1", "COMPANY_NAME"),
                ("siret", "COMPANY_ID"),
                ("company2", "COMPANY_NAME"),
                ("person", "DIRIGEANT"),
                ("amount", "CONTRACT_AMOUNT"),
                ("iban", "BANK_ACCOUNT"),
            ]
        },
        {
            "text": "Dans le cadre de l'audit ISO 27001 référencé AUD-2024-042, {person} (DPO de {company}) a transmis à {lawyer} l'ensemble des registres de traitement pour vérification de conformité avant le {date}.",
            "entities": [
                ("person", "DPO_IDENTITY"),
                ("company", "COMPANY_NAME"),
                ("lawyer", "LEGAL_COUNSEL"),
                ("date", "DEADLINE"),
            ]
        },
    ]

    def generate(self, n_samples: int = 1000) -> List[Dict[str, Any]]:
        """Generate n_samples synthetic annotated examples."""
        samples = []

        for _ in range(n_samples):
            template = random.choice(self.TEMPLATES)
            text_template = template["text"]
            entity_map = template["entities"]

            # Draw a random value for every slot a template may reference.
            values = {
                "lawyer": random.choice(self.LAWYERS),
                "person": random.choice(self.PERSON_NAMES_FR),
                "company": random.choice(self.COMPANY_NAMES),
                "company1": random.choice(self.COMPANY_NAMES),
                "company2": random.choice(self.COMPANY_NAMES),
                "siret": random.choice(self.SIRETS),
                "amount": random.choice(self.AMOUNTS),
                "address": random.choice(self.ADDRESSES),
                "email": random.choice(self.EMAILS),
                "iban": random.choice(self.IBANS),
                "date": "30 juin 2025",
            }

            # Fill the template with the sampled values.
            text = text_template.format(**values)

            # Recover character offsets for each slot. Entities are listed in
            # the order they appear in the template, so searching from a moving
            # cursor keeps duplicate values (e.g. company1 == company2) from
            # mapping to the same span twice.
            annotations = []
            cursor = 0
            for field_name, label in entity_map:
                if field_name not in values:
                    continue
                value = values[field_name]
                start = text.find(value, cursor)
                if start == -1:
                    continue
                cursor = start + len(value)
                annotations.append({
                    "start": start,
                    "end": start + len(value),
                    "label": label,
                    "text": value
                })

            samples.append({
                "text": text,
                "annotations": annotations
            })

        return samples
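# Shape of one generated sample (values illustrative):
# {
#     "text": "M. Jean Dupont, domicilié au 12 rue de la Paix, 75001 Paris, ...",
#     "annotations": [
#         {"start": 3, "end": 14, "label": "PERSON_NAME", "text": "Jean Dupont"},
#         ...
#     ],
# }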
# =============================================================================
# Tokenization & label alignment
# =============================================================================

def tokenize_and_align_labels(
    examples: Dict[str, Any],
    tokenizer,
    max_length: int = 512,
    label_all_tokens: bool = False,
) -> Dict[str, Any]:
    """
    Tokenize text and align BIOES labels with subword tokens.
    Handles the subword alignment problem for NER with the BIOES scheme.
    Padding is left to the data collator, which pads dynamically per batch.
    """
    tokenized = tokenizer(
        examples["tokens"],
        truncation=True,
        max_length=max_length,
        is_split_into_words=True,
    )

    aligned_labels = []
    for i, label_ids in enumerate(examples["labels"]):
        word_ids = tokenized.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids_aligned = []

        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens (BOS/EOS/padding) are ignored by the loss.
                label_ids_aligned.append(-100)
            elif word_idx != previous_word_idx:
                # The first subword of a word carries the word's label.
                label_ids_aligned.append(label_ids[word_idx])
            else:
                # Continuation subword: repeating B-/E- tags here can produce
                # invalid BIOES sequences, so the default masks them with -100.
                if label_all_tokens:
                    label_ids_aligned.append(label_ids[word_idx])
                else:
                    label_ids_aligned.append(-100)
            previous_word_idx = word_idx

        aligned_labels.append(label_ids_aligned)

    tokenized["labels"] = aligned_labels
    return tokenized
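# Alignment illustration (actual splits depend on the tokenizer): for words
# ["Maître", "Dupont"] labeled ["B-LEGAL_COUNSEL", "E-LEGAL_COUNSEL"], a
# tokenizer that adds BOS/EOS and splits "Maître" into two subwords produces
# word_ids [None, 0, 0, 1, None], so the aligned label ids become
# [-100, B-LEGAL_COUNSEL, -100, E-LEGAL_COUNSEL, -100] with the default
# label_all_tokens=False.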
def char_annotations_to_token_labels(
    text: str,
    annotations: List[Dict],
) -> Dict[str, Any]:
    """
    Convert character-level annotations to word-level BIOES token labels.
    """
    words = text.split()
    word_labels = ["O"] * len(words)

    # Map every character offset to the index of the word containing it
    # (whitespace positions are simply skipped).
    char_to_word = {}
    char_pos = 0
    for word_idx, word in enumerate(words):
        for _ in word:
            char_to_word[char_pos] = word_idx
            char_pos += 1
        char_pos += 1  # account for the separating space

    # Tag each annotated span: S- for single-word entities, otherwise
    # B- ... I- ... E- across the span.
    for ann in annotations:
        start_word = char_to_word.get(ann["start"])
        end_char = ann["end"] - 1
        end_word = char_to_word.get(end_char, start_word)

        if start_word is None:
            continue

        label = ann["label"]
        span_words = list(range(start_word, end_word + 1))

        if len(span_words) == 1:
            word_labels[span_words[0]] = f"S-{label}"
        else:
            word_labels[span_words[0]] = f"B-{label}"
            for w in span_words[1:-1]:
                word_labels[w] = f"I-{label}"
            word_labels[span_words[-1]] = f"E-{label}"

    return {
        "tokens": words,
        # Unknown tags fall back to "O" (index 0).
        "labels": [LABEL2ID.get(l, 0) for l in word_labels],
    }
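# Worked example: for text "Maître Jean Dupont, avocat" with a single
# annotation {"start": 0, "end": 18, "label": "LEGAL_COUNSEL"}, the words
# ["Maître", "Jean", "Dupont,", "avocat"] are tagged
# ["B-LEGAL_COUNSEL", "I-LEGAL_COUNSEL", "E-LEGAL_COUNSEL", "O"]
# (the comma stays attached to "Dupont," because splitting is whitespace-based).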
# =============================================================================
# Metrics
# =============================================================================

seqeval = evaluate.load("seqeval")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # Drop positions masked with -100 and map ids back to tag strings.
    true_predictions = [
        [ID2LABEL[p] for p, l in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [ID2LABEL[l] for p, l in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    # seqeval calls the B/I/E/S scheme "IOBES"; strict mode only counts an
    # entity as correct when both its span and its type match exactly.
    results = seqeval.compute(
        predictions=true_predictions,
        references=true_labels,
        scheme="IOBES",
        mode="strict",
    )

    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
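# Strict-mode intuition: with gold tags
# ["B-PERSON_NAME", "I-PERSON_NAME", "E-PERSON_NAME"] and predictions
# ["B-PERSON_NAME", "E-PERSON_NAME", "O"], the predicted span covers only two
# of the three words, so the entity earns no credit at the entity level even
# though the first tag matches token by token.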
# =============================================================================
# Model loading (QLoRA)
# =============================================================================

def load_model_and_tokenizer(
    base_model: str = "mistralai/Mistral-7B-v0.1",
    use_4bit: bool = True,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
):
    """Load Mistral with 4-bit quantization + LoRA for efficient fine-tuning."""

    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        padding_side="right",
        add_eos_token=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 4-bit NF4 quantization with double quantization (the QLoRA recipe).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    ) if use_4bit else None

    # Base model with a freshly initialized token-classification head.
    model = AutoModelForTokenClassification.from_pretrained(
        base_model,
        num_labels=NUM_LABELS,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    if use_4bit:
        model = prepare_model_for_kbit_training(model)

    # LoRA adapters on the attention and MLP projections. The classification
    # head of Mistral/Llama-style *ForTokenClassification models is named
    # "score", so it goes in modules_to_save to be trained and saved in full.
    lora_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        target_modules=[
            "q_proj", "v_proj", "k_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        modules_to_save=["score"],
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    return model, tokenizer
# =============================================================================
# Dataset preparation
# =============================================================================

def prepare_dataset(
    tokenizer,
    n_synthetic: int = 5000,
    hf_dataset: Optional[str] = None,
    max_length: int = 512,
) -> DatasetDict:
    """
    Prepare training dataset.
    Combines synthetic data + optional HuggingFace dataset.
    Note: hf_dataset is reserved for mixing in a real corpus via load_dataset;
    it is not wired up yet.
    """
    generator = LegalDataGenerator()
    raw_samples = generator.generate(n_samples=n_synthetic)

    # Convert character-level annotations to word-level BIOES labels.
    token_samples = []
    for sample in raw_samples:
        converted = char_annotations_to_token_labels(
            sample["text"],
            sample["annotations"],
        )
        token_samples.append(converted)

    # 80/10/10 train/validation/test split.
    n = len(token_samples)
    n_train = int(n * 0.80)
    n_val = int(n * 0.10)

    dataset = DatasetDict({
        "train": Dataset.from_list(token_samples[:n_train]),
        "validation": Dataset.from_list(token_samples[n_train:n_train + n_val]),
        "test": Dataset.from_list(token_samples[n_train + n_val:]),
    })

    # Tokenize and align labels with subword tokens.
    tokenized = dataset.map(
        lambda x: tokenize_and_align_labels(x, tokenizer, max_length),
        batched=True,
        remove_columns=["tokens", "labels"],
    )

    logger.info(f"Dataset sizes - train: {len(tokenized['train'])}, "
                f"val: {len(tokenized['validation'])}, "
                f"test: {len(tokenized['test'])}")

    return tokenized
# =============================================================================
# Training
# =============================================================================

def train(
    base_model: str = "mistralai/Mistral-7B-v0.1",
    output_dir: str = "./privamesh-legal-output",
    n_synthetic: int = 5000,
    num_epochs: int = 5,
    batch_size: int = 4,
    gradient_accumulation: int = 4,
    learning_rate: float = 2e-4,
    max_length: int = 512,
    use_4bit: bool = True,
    push_to_hub: bool = False,
    hub_model_id: str = "sallani/PrivaMesh",
):
    logger.info("=" * 60)
    logger.info("PrivaMesh Legal - Fine-tuning Mistral-7B")
    logger.info(f"Base model : {base_model}")
    logger.info(f"Output dir : {output_dir}")
    logger.info(f"Labels : {NUM_LABELS} ({len(PRIVACY_CATEGORIES)} categories x 4 BIOES tags + O)")
    logger.info("=" * 60)

    # Model + tokenizer (QLoRA).
    model, tokenizer = load_model_and_tokenizer(
        base_model=base_model,
        use_4bit=use_4bit,
    )

    # Synthetic dataset, tokenized and label-aligned.
    dataset = prepare_dataset(
        tokenizer=tokenizer,
        n_synthetic=n_synthetic,
        max_length=max_length,
    )

    data_collator = DataCollatorForTokenClassification(
        tokenizer=tokenizer,
        padding=True,
        max_length=max_length,
    )

    # Effective batch size = batch_size x gradient_accumulation (16 by default).
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,
        learning_rate=learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        weight_decay=0.01,

        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,

        logging_dir=f"{output_dir}/logs",
        logging_steps=50,
        report_to="none",

        bf16=True,
        tf32=True,
        dataloader_num_workers=4,
        group_by_length=True,

        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id if push_to_hub else None,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info("Evaluating on test set...")
    test_results = trainer.evaluate(dataset["test"])
    logger.info(f"Test results: {json.dumps(test_results, indent=2)}")

    logger.info(f"Saving model to {output_dir}")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Embed PrivaMesh metadata. With a PEFT model, save_model writes
    # adapter_config.json rather than config.json, so create the file if it
    # does not exist yet.
    config_path = os.path.join(output_dir, "config.json")
    config = {}
    if os.path.exists(config_path):
        with open(config_path) as f:
            config = json.load(f)
    config["privamesh"] = {
        "version": "1.0.0",
        "domain": "legal",
        "languages": ["fr", "en"],
        "regulatory_coverage": ["RGPD", "DORA", "NIS2", "ISO27001", "ISO42001"],
        "mesh_role": "specialist",
        "test_f1": test_results.get("eval_f1", 0.0),
    }
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

    if push_to_hub:
        logger.info(f"Pushing to HuggingFace Hub: {hub_model_id}")
        trainer.push_to_hub()

    logger.info("Done.")
    return trainer, test_results
# =============================================================================
# Inference API
# =============================================================================

@dataclass
class AnonymizationResult:
    anonymized_text: str
    entities: List[Dict[str, Any]]
    semantic_score: float
    privacy_recall: float

    def __repr__(self):
        return (
            f"AnonymizationResult(\n"
            f"  anonymized_text={self.anonymized_text[:80]}...\n"
            f"  entities={len(self.entities)} found\n"
            f"  semantic_score={self.semantic_score:.3f}\n"
            f"  privacy_recall={self.privacy_recall:.3f}\n"
            f")"
        )
class PrivaMeshLegal:
    """
    PrivaMesh Legal - High-level API for semantic legal document anonymization.
    Runs fully on-premise. No data leaves your infrastructure.
    """

    def __init__(self, model, tokenizer, threshold: float = 0.50):
        self.model = model
        self.tokenizer = tokenizer
        self.threshold = threshold
        self._entity_counters: Dict[str, int] = {}
        self._entity_registry: Dict[str, str] = {}

    @classmethod
    def from_pretrained(
        cls,
        model_path: str = "sallani/PrivaMesh",
        device_map: str = "auto",
        use_4bit: bool = False,
        local_files_only: bool = False,
        threshold: float = 0.50,
    ):
        """Load PrivaMesh Legal from HuggingFace Hub or local path."""
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            local_files_only=local_files_only,
        )

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ) if use_4bit else None

        model = AutoModelForTokenClassification.from_pretrained(
            model_path,
            quantization_config=bnb_config,
            device_map=device_map,
            torch_dtype=torch.bfloat16,
            local_files_only=local_files_only,
        )
        model.eval()

        return cls(model, tokenizer, threshold)

    def _get_placeholder(self, label: str, text: str, context: Optional[dict] = None) -> str:
        """Get a consistent numbered placeholder for a detected entity."""
        # A shared context dict keeps placeholders consistent across documents;
        # its per-label counters live under the reserved "_counters" key.
        if context is not None:
            registry = context
            counters = context.setdefault("_counters", {})
        else:
            registry = self._entity_registry
            counters = self._entity_counters

        key = f"{label}::{text.lower().strip()}"
        if key not in registry:
            counters[label] = counters.get(label, 0) + 1
            template = PLACEHOLDER_TEMPLATES.get(label, f"[{label}_{{n}}]")
            placeholder = template.format(n=counters[label])
            registry[key] = placeholder

        return registry[key]
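    # Placeholder consistency example: repeated mentions of the same surface
    # form (case-insensitive) reuse one placeholder, distinct entities get
    # fresh numbers:
    #   _get_placeholder("PERSON_NAME", "Jean Dupont")  -> "[PERSONNE_1]"
    #   _get_placeholder("PERSON_NAME", "jean dupont")  -> "[PERSONNE_1]"
    #   _get_placeholder("PERSON_NAME", "Marie Martin") -> "[PERSONNE_2]"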
    def anonymize(
        self,
        text: str,
        operating_point: str = "balanced",
        active_labels: Optional[List[str]] = None,
        preserve_labels: Optional[List[str]] = None,
        context: Optional[dict] = None,
        language: str = "auto",
    ) -> AnonymizationResult:
        """
        Anonymize a single document.

        Args:
            text: Input text to anonymize
            operating_point: "high_recall" | "balanced" | "high_precision"
            active_labels: Only anonymize these label types (None = all)
            preserve_labels: Never anonymize these label types
            context: Shared dict (see _get_placeholder) for cross-document consistency
            language: "fr" | "en" | "auto" (currently informational only)

        Returns:
            AnonymizationResult with anonymized_text, entities, scores
        """
        thresholds = {
            "high_recall": 0.35,
            "balanced": 0.50,
            "high_precision": 0.70,
        }
        threshold = thresholds.get(operating_point, self.threshold)

        # Split on whitespace so predictions can be mapped back to words.
        words = text.split()
        inputs = self.tokenizer(
            words,
            is_split_into_words=True,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(self.model.device)

        # Forward pass.
        with torch.no_grad():
            outputs = self.model(**inputs)

        logits = outputs.logits[0]
        probs = torch.softmax(logits, dim=-1)
        pred_ids = torch.argmax(probs, dim=-1)

        # Keep the prediction of the first subword of each word.
        word_ids = inputs.word_ids(batch_index=0)
        word_predictions = {}
        for token_idx, word_idx in enumerate(word_ids):
            if word_idx is None:
                continue
            if word_idx not in word_predictions:
                label_id = pred_ids[token_idx].item()
                confidence = probs[token_idx][label_id].item()
                word_predictions[word_idx] = (ID2LABEL[label_id], confidence)

        # Decode BIOES tags into entity spans.
        entities = []
        i = 0
        while i < len(words):
            if i not in word_predictions:
                i += 1
                continue

            label, conf = word_predictions[i]
            if label == "O" or conf < threshold:
                i += 1
                continue

            prefix, cat = label.split("-", 1)

            # Honor the allow/deny label filters.
            if active_labels and cat not in active_labels:
                i += 1
                continue
            if preserve_labels and cat in preserve_labels:
                i += 1
                continue

            # Collect the words belonging to this entity span.
            span_words = [words[i]]
            start_word = i

            if prefix == "S":
                pass  # single-word entity, span is complete
            elif prefix == "B":
                # Extend through I- tags of the same category up to E-.
                i += 1
                while i < len(words) and i in word_predictions:
                    next_label, _ = word_predictions[i]
                    if next_label in (f"I-{cat}", f"E-{cat}"):
                        span_words.append(words[i])
                        if next_label == f"E-{cat}":
                            break
                        i += 1
                    else:
                        i -= 1
                        break

            entity_text = " ".join(span_words)
            placeholder = self._get_placeholder(cat, entity_text, context)

            entities.append({
                "label": cat,
                "text": entity_text,
                "start_word": start_word,
                "end_word": start_word + len(span_words) - 1,
                "replacement": placeholder,
                "confidence": conf,
            })
            i += 1

        # Substitute spans right-to-left so earlier word indices stay valid.
        anonymized_words = words.copy()
        for entity in sorted(entities, key=lambda e: e["start_word"], reverse=True):
            start = entity["start_word"]
            end = entity["end_word"] + 1
            anonymized_words[start:end] = [entity["replacement"]]

        anonymized_text = " ".join(anonymized_words)

        # NOTE: semantic_score and privacy_recall are static placeholders;
        # a real semantic-similarity / recall estimator belongs in the
        # evaluation pipeline.
        return AnonymizationResult(
            anonymized_text=anonymized_text,
            entities=entities,
            semantic_score=0.94,
            privacy_recall=1.0 if entities else 0.0,
        )
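    # Usage sketch (output illustrative; spans depend on the trained model):
    #   pm = PrivaMeshLegal.from_pretrained("sallani/PrivaMesh")
    #   result = pm.anonymize("Maître Jean Dupont représente Nexum SAS.",
    #                         operating_point="high_recall")
    #   result.anonymized_text  # e.g. "[AVOCAT_1] représente [SOCIETE_1]."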
    def anonymize_batch(
        self,
        texts: List[str],
        batch_size: int = 16,
        **kwargs,
    ) -> List[AnonymizationResult]:
        """Anonymize a list of documents in batches."""
        # Documents are processed one by one; batch_size only controls how
        # often progress is logged (true batched inference is a TODO).
        results = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            for text in batch:
                results.append(self.anonymize(text, **kwargs))
            logger.info(f"Processed {min(i + batch_size, len(texts))}/{len(texts)} documents")
        return results
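# Cross-document consistency sketch: share one context dict across calls so
# the same entity keeps the same placeholder in related documents
# (contract_text / annex_text are hypothetical variables):
#   ctx = {}
#   r1 = pm.anonymize(contract_text, context=ctx)
#   r2 = pm.anonymize(annex_text, context=ctx)  # reuses "[PERSONNE_1]", etc.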
# =============================================================================
# CLI entry point
# =============================================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PrivaMesh Legal - Fine-tuning & Inference")
    subparsers = parser.add_subparsers(dest="command")

    # train subcommand
    train_parser = subparsers.add_parser("train", help="Fine-tune Mistral for legal PII anonymization")
    train_parser.add_argument("--base-model", default="mistralai/Mistral-7B-v0.1")
    train_parser.add_argument("--output-dir", default="./privamesh-legal-output")
    train_parser.add_argument("--n-synthetic", type=int, default=5000)
    train_parser.add_argument("--epochs", type=int, default=5)
    train_parser.add_argument("--batch-size", type=int, default=4)
    train_parser.add_argument("--learning-rate", type=float, default=2e-4)
    # BooleanOptionalAction also generates --no-use-4bit; a bare store_true
    # with default=True could never be switched off.
    train_parser.add_argument("--use-4bit", action=argparse.BooleanOptionalAction, default=True)
    train_parser.add_argument("--push-to-hub", action="store_true")
    train_parser.add_argument("--hub-model-id", default="sallani/PrivaMesh")

    # test subcommand
    test_parser = subparsers.add_parser("test", help="Test anonymization on a sample text")
    test_parser.add_argument("--model", default="sallani/PrivaMesh")
    test_parser.add_argument("--text", default=None)
    test_parser.add_argument("--mode", default="balanced",
                             choices=["high_recall", "balanced", "high_precision"])

    args = parser.parse_args()

    if args.command == "train":
        train(
            base_model=args.base_model,
            output_dir=args.output_dir,
            n_synthetic=args.n_synthetic,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            use_4bit=args.use_4bit,
            push_to_hub=args.push_to_hub,
            hub_model_id=args.hub_model_id,
        )

    elif args.command == "test":
        sample = args.text or (
            "Le contrat conclu entre Maître Jean Dupont, avocat au barreau de Paris "
            "(SIRET 123 456 789 00012), et la société Nexum SAS, représentée par "
            "M. Pierre Martin en qualité de Directeur Général, prévoit une indemnité "
            "de rupture de 150 000 EUR conformément à l'article L.1237-19 du Code du travail."
        )

        print(f"\nInput:\n{sample}\n")
        print("Loading model...")
        privamesh = PrivaMeshLegal.from_pretrained(args.model)

        result = privamesh.anonymize(sample, operating_point=args.mode)

        print(f"\nAnonymized:\n{result.anonymized_text}\n")
        print(f"Entities detected: {len(result.entities)}")
        for e in result.entities:
            print(f"  [{e['label']}] '{e['text']}' -> '{e['replacement']}' (conf: {e['confidence']:.2f})")

    else:
        parser.print_help()