| """ |
| Configuration and utility functions for Rwanda Legal NLP System |
| """ |
|
|
| import os |
| import json |
| import logging |
| import pandas as pd |
| from typing import Dict, Any |
| from dataclasses import dataclass |
|
|
| |
# Configure the root logger once at import time; every module that imports
# this file inherits this level and format. NOTE(review): basicConfig is a
# no-op if the root logger was already configured by the importing app.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
|
|
@dataclass
class ModelConfig:
    """Configuration for model settings (text-generation defaults).

    Field values mirror common HuggingFace-style ``generate()`` parameters —
    presumably consumed by the model wrapper elsewhere in the project; confirm.
    """
    model_name: str = "EleutherAI/gpt-j-6B"  # model identifier to load
    max_length: int = 512          # maximum tokens per generated sequence
    temperature: float = 0.7       # sampling temperature (>0; lower = more deterministic)
    top_p: float = 0.9             # nucleus-sampling cumulative-probability cutoff
    top_k: int = 50                # top-k sampling cutoff
    do_sample: bool = True         # sample instead of greedy decoding
    num_return_sequences: int = 1  # candidate sequences returned per prompt
|
|
@dataclass
class TrainingConfig:
    """Configuration for training settings (fine-tuning hyperparameters)."""
    output_dir: str = "./trained_legal_model"  # where checkpoints are written
    num_epochs: int = 3            # full passes over the training set
    batch_size: int = 2            # per-device batch size
    learning_rate: float = 5e-5
    warmup_steps: int = 500        # LR warmup steps before decay
    save_steps: int = 1000         # checkpoint every N optimizer steps
    eval_steps: int = 1000         # evaluate every N optimizer steps
    max_grad_norm: float = 1.0     # gradient-clipping threshold
    gradient_accumulation_steps: int = 4  # effective batch = batch_size * this
    weight_decay: float = 0.01
|
|
@dataclass
class DataConfig:
    """Configuration for data processing."""
    dataset_path: str = "dataset-all.csv"  # CSV resolved relative to the CWD
    max_text_length: int = 512     # truncation length for input texts
    test_size: float = 0.1         # fraction of data held out for evaluation
    random_state: int = 42         # seed for reproducible splits
|
|
class ConfigManager:
    """Load, hold, and persist configuration for the legal NLP system.

    Holds one instance each of ModelConfig, TrainingConfig and DataConfig,
    optionally overridden by a JSON file of the form
    ``{"model": {...}, "training": {...}, "data": {...}}``.
    """

    def __init__(self, config_path: str = "config.json"):
        self.config_path = config_path
        self.model_config = ModelConfig()
        self.training_config = TrainingConfig()
        self.data_config = DataConfig()

        # Apply overrides from disk immediately, if a file is present.
        self.load_config()

    def _sections(self) -> dict:
        """Map JSON section names to their live config objects."""
        return {
            'model': self.model_config,
            'training': self.training_config,
            'data': self.data_config,
        }

    def load_config(self):
        """Load configuration overrides from the JSON file, if it exists.

        Unknown sections and unknown keys are silently ignored, so older or
        partial config files keep working. Errors are logged, never raised.
        """
        if not os.path.exists(self.config_path):
            return

        try:
            with open(self.config_path, 'r') as f:
                config_data = json.load(f)

            # One loop handles all sections; only attributes that already
            # exist on the dataclass are overwritten.
            for section, target in self._sections().items():
                for key, value in config_data.get(section, {}).items():
                    if hasattr(target, key):
                        setattr(target, key, value)

            logging.info(f"Configuration loaded from {self.config_path}")

        except Exception as e:
            # Broad on purpose: a bad config file must never crash startup.
            logging.warning(f"Could not load config from {self.config_path}: {e}")

    def save_config(self):
        """Save the current configuration to the JSON file.

        Errors are logged, never raised.
        """
        config_data = {
            section: target.__dict__
            for section, target in self._sections().items()
        }

        try:
            with open(self.config_path, 'w') as f:
                json.dump(config_data, f, indent=2)

            logging.info(f"Configuration saved to {self.config_path}")

        except Exception as e:
            logging.error(f"Could not save config to {self.config_path}: {e}")
|
|
| |
# Function words and very common nouns to drop during keyword extraction
# (see extract_keywords_kinyarwanda). Membership tests are lowercase.
KINYARWANDA_STOPWORDS = {
    'ni', 'na', 'ku', 'mu', 'nk', 'no', 'cyangwa', 'ariko', 'naho', 'none',
    'kandi', 'rero', 'ubwo', 'uko', 'ubu', 'aha', 'aho', 'iyo', 'ese',
    'nta', 'nti', 'nte', 'nto', 'ntu', 'ntw', 'aba', 'ari', 'hari',
    'kuri', 'muri', 'buri', 'abantu', 'umuntu', 'ibintu', 'ikintu'
}
|
|
| |
# Kinyarwanda legal term -> English gloss. Keys with underscores represent
# multi-word phrases. Not referenced elsewhere in this file; presumably used
# by downstream display/translation code — confirm against callers.
KINYARWANDA_LEGAL_TERMS = {
    'gusambanya': 'sexual defilement',
    'kwiba': 'theft',
    'gukoresha_imbaraga': 'use of force/violence',
    'kwinjira': 'enter/trespass',
    'kwica': 'kill/murder',
    'gukubita': 'assault/beat',
    'uburiganya': 'fraud/deception',
    'ubuhemu': 'embezzlement',
    'igifungo': 'imprisonment',
    'ihazabu': 'fine',
    'igihano': 'punishment',
    'ingingo': 'article',
    'itegeko': 'law',
    'umwana': 'child',
    'imyaka': 'years',
    'amezi': 'months',
    'burundu': 'life (imprisonment)',
    'gahato': 'force/violence',
    'imibonano_mpuzabitsina': 'sexual intercourse',
    'inyamaswa': 'animals',
    'rugo': 'home/house'
}
|
|
def clean_kinyarwanda_text(text: str) -> str:
    """Clean and normalize Kinyarwanda text.

    Collapses whitespace, strips characters outside a whitelist (word chars,
    whitespace, basic punctuation, Latin accented chars), then replaces
    monetary amounts, numeric ranges and bare numbers with the placeholder
    tokens AMOUNT, RANGE and NUMBER respectively.

    Returns "" for None/NaN/empty input.
    """
    import re

    if not text or pd.isna(text):
        return ""

    text = str(text)

    # Collapse runs of whitespace to single spaces.
    text = re.sub(r'\s+', ' ', text)

    # Keep word characters, whitespace, basic punctuation and Latin
    # accented characters; drop everything else.
    text = re.sub(r'[^\w\s\-\.\,\;\:\!\?\'\"\u00C0-\u017F]', '', text)

    # Substitute the most specific numeric patterns FIRST: running the
    # generic NUMBER replacement before AMOUNT/RANGE (as the original did)
    # consumed the digits inside "FRW 1,000" and "10-15", so those
    # placeholders could never match.
    text = re.sub(r'\bFRW\s*[0-9,\.]+\b', ' AMOUNT ', text)
    text = re.sub(r'\b[0-9]+-[0-9]+\b', ' RANGE ', text)
    text = re.sub(r'\b[0-9]+\b', ' NUMBER ', text)

    # Re-collapse whitespace introduced by the padded placeholders.
    return re.sub(r'\s+', ' ', text).strip()
|
|
def extract_keywords_kinyarwanda(text: str, max_keywords: int = 10) -> list:
    """Extract the most frequent content words from Kinyarwanda text.

    Cleans and lowercases the text, discards stopwords and tokens of at most
    two characters, and returns up to ``max_keywords`` words ordered by
    descending frequency.
    """
    if not text:
        return []

    from collections import Counter

    # Count qualifying tokens in a single pass.
    counts = Counter()
    for token in clean_kinyarwanda_text(text).lower().split():
        if len(token) > 2 and token not in KINYARWANDA_STOPWORDS:
            counts[token] += 1

    return [token for token, _ in counts.most_common(max_keywords)]
|
|
| |
# Keyword lists used by categorize_case(). Each category carries English and
# Kinyarwanda trigger terms; matches are substring tests against the
# lowercased description (English hits score 2, Kinyarwanda hits score 3).
LEGAL_CATEGORIES = {
    'sexual_offence': {
        'english': ['sexual', 'rape', 'assault', 'child', 'defilement'],
        'kinyarwanda': ['gukoresha', 'gusambanya', 'igitsina', 'umwana', 'imibonano']
    },
    'theft': {
        'english': ['theft', 'robbery', 'stealing', 'property'],
        'kinyarwanda': ['kwiba', 'gufata', 'umutungo', 'imbaraga']
    },
    'privacy': {
        'english': ['privacy', 'domicile', 'recording', 'entry'],
        'kinyarwanda': ['kwinjira', 'kumviriza', 'rugo', 'ubuzima_bwite']
    },
    'morality': {
        'english': ['adultery', 'bigamy', 'concubinage', 'marriage'],
        'kinyarwanda': ['ubusambanyi', 'ubushoreke', 'gushyingirwa', 'guta_urugo']
    },
    'violence': {
        'english': ['violence', 'murder', 'genocide', 'torture', 'assault'],
        'kinyarwanda': ['kwica', 'gukubita', 'jenoside', 'ihohotera', 'imbaraga']
    },
    'fraud': {
        'english': ['fraud', 'forgery', 'deception', 'embezzlement'],
        'kinyarwanda': ['uburiganya', 'kwigana', 'ubuhemu', 'kwibeshya']
    }
}
|
|
def categorize_case(description: str) -> str:
    """Categorize a case based on description keywords (English and Kinyarwanda).

    Scores each LEGAL_CATEGORIES entry by substring matches in the lowercased
    description (English hit = 2 points, Kinyarwanda hit = 3 points) and
    returns the highest-scoring category. Returns "unknown" for empty input
    and "general" when nothing matches.
    """
    if not description:
        return "unknown"

    text = description.lower()

    # Running maximum; first category wins ties, matching dict order.
    best_category = "general"
    best_score = 0

    for category, terms in LEGAL_CATEGORIES.items():
        score = sum(2 for keyword in terms['english'] if keyword in text)
        score += sum(3 for keyword in terms['kinyarwanda'] if keyword in text)

        if score > best_score:
            best_score = score
            best_category = category

    return best_category
|
|
| |
def calculate_similarity_score(text1: str, text2: str) -> float:
    """Calculate simple similarity score between two texts.

    Jaccard similarity over the lowercased, cleaned word sets:
    |intersection| / |union|. Returns 0.0 when either input is empty or
    cleans down to nothing.
    """
    if not text1 or not text2:
        return 0.0

    tokens_a = set(clean_kinyarwanda_text(text1).lower().split())
    tokens_b = set(clean_kinyarwanda_text(text2).lower().split())

    if not tokens_a or not tokens_b:
        return 0.0

    shared = tokens_a & tokens_b
    combined = tokens_a | tokens_b

    return len(shared) / len(combined) if combined else 0.0
|
|
| |
def save_predictions(predictions: list, output_path: str):
    """Write predictions to *output_path* as pretty-printed UTF-8 JSON.

    Non-ASCII characters are kept as-is (ensure_ascii=False). Failures are
    logged rather than raised.
    """
    try:
        with open(output_path, 'w', encoding='utf-8') as fh:
            json.dump(predictions, fh, indent=2, ensure_ascii=False)
    except Exception as e:
        logging.error(f"Could not save predictions: {e}")
    else:
        logging.info(f"Predictions saved to {output_path}")
|
|
def load_predictions(input_path: str) -> list:
    """Read predictions back from a UTF-8 JSON file.

    Returns an empty list (and logs an error) when the file is missing or
    cannot be parsed.
    """
    try:
        with open(input_path, 'r', encoding='utf-8') as fh:
            data = json.load(fh)
    except Exception as e:
        logging.error(f"Could not load predictions: {e}")
        return []

    logging.info(f"Predictions loaded from {input_path}")
    return data
|
|
def format_punishment(punishment: str) -> dict:
    """Parse and format punishment information.

    Extracts imprisonment terms (years), fines (FRW) and community service
    from a free-text punishment description in English or Kinyarwanda.

    Returns a dict with keys: "type", "imprisonment", "fine",
    "community_service", "details". Returns {"type": "unknown",
    "details": ""} for None/NaN/empty input.
    """
    import re

    if not punishment or pd.isna(punishment):
        return {"type": "unknown", "details": ""}

    punishment = str(punishment).lower()

    result = {
        "type": "unknown",
        "imprisonment": None,
        "fine": None,
        "community_service": None,
        "details": punishment
    }

    # Imprisonment in years. English puts the number before the unit
    # ("10-15 years") but Kinyarwanda puts the unit first ("imyaka 10-15");
    # the original number-first pattern alone missed Kinyarwanda phrasing
    # entirely. Try number-first, then unit-first as a fallback.
    year_match = (
        re.search(r'(\d+)(?:\s*-\s*(\d+))?\s*(?:years?|imyaka)', punishment)
        or re.search(r'(?:years?|imyaka)\s*(\d+)(?:\s*-\s*(\d+))?', punishment)
    )
    if year_match:
        min_years = int(year_match.group(1))
        max_years = int(year_match.group(2)) if year_match.group(2) else min_years
        result["imprisonment"] = f"{min_years}-{max_years} years"
        result["type"] = "imprisonment"

    # Life imprisonment overrides any numeric term found above.
    if any(term in punishment for term in ['life', 'burundu', 'cya burundu']):
        result["imprisonment"] = "Life imprisonment"
        result["type"] = "life_imprisonment"

    # Fine amounts in Rwandan francs, e.g. "frw 1,000,000-2,000,000".
    # A single amount is reported as "min-min" for a uniform format.
    fine_match = re.search(r'frw\s*([0-9,\.]+)(?:\s*-\s*([0-9,\.]+))?', punishment)
    if fine_match:
        min_fine = fine_match.group(1)
        max_fine = fine_match.group(2) if fine_match.group(2) else min_fine
        result["fine"] = f"FRW {min_fine}-{max_fine}"

    if 'community service' in punishment or 'inyungu rusange' in punishment:
        result["community_service"] = "Yes"

    return result
|
|
| |
| config_manager = ConfigManager() |
|
|
if __name__ == "__main__":
    # Smoke test: run each utility on a short Kinyarwanda sample and print
    # the results for manual inspection.
    sample_text = "Umuntu yakoreye umwana igikorwa gishingiye ku gitsina"

    print("Sample text:", sample_text)
    print("Cleaned:", clean_kinyarwanda_text(sample_text))
    print("Keywords:", extract_keywords_kinyarwanda(sample_text))
    print("Category:", categorize_case(sample_text))

    # A punishment string mixing Kinyarwanda terms with an FRW fine range.
    sample_punishment = "Igifungo cy'imyaka 10-15 + ihazabu FRW 1,000,000-2,000,000"
    print("Formatted punishment:", format_punishment(sample_punishment))