mbaza / config.py
mugwaneza's picture
Deploy Kinyarwanda Legal Assistant with Docker
2979b34
"""
Configuration and utility functions for Rwanda Legal NLP System
"""
import json
import logging
import os
import re
from collections import Counter
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict

import pandas as pd
# Root logger setup: emit INFO and above with a timestamp, the logger name and
# the level in front of each message.  Per the logging documentation this call
# does nothing if the root logger already has handlers configured.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
@dataclass
class ModelConfig:
    """Default model settings.

    Every field can be overridden by the ``model`` section of the JSON file
    read by ``ConfigManager``.  The field order is part of the positional
    constructor signature and must not be changed.
    """

    # Model identifier (looks like a Hugging Face Hub id -- confirm with the
    # code that actually loads the model).
    model_name: str = "EleutherAI/gpt-j-6B"
    # Length limit; the unit (tokens vs. characters) is not defined in this file.
    max_length: int = 512
    # The remaining fields are presumably text-generation sampling options
    # forwarded to the model's generate call -- verify against the consumer.
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 50
    do_sample: bool = True
    num_return_sequences: int = 1
@dataclass
class TrainingConfig:
    """Default training settings.

    Every field can be overridden by the ``training`` section of the JSON
    file read by ``ConfigManager``.  The field order is part of the
    positional constructor signature and must not be changed.

    NOTE(review): the names follow common fine-tuning hyperparameter
    vocabulary; their exact meaning depends on the training script, which is
    not part of this file.
    """

    # Where training artefacts are written (relative path by default).
    output_dir: str = "./trained_legal_model"
    num_epochs: int = 3
    batch_size: int = 2
    learning_rate: float = 5e-5
    warmup_steps: int = 500
    # Step intervals -- presumably for checkpointing and evaluation.
    save_steps: int = 1000
    eval_steps: int = 1000
    max_grad_norm: float = 1.0
    gradient_accumulation_steps: int = 4
    weight_decay: float = 0.01
@dataclass
class DataConfig:
    """Default data-processing settings.

    Every field can be overridden by the ``data`` section of the JSON file
    read by ``ConfigManager``.  The field order is part of the positional
    constructor signature and must not be changed.
    """

    # Dataset file, relative to the working directory by default.
    dataset_path: str = "dataset-all.csv"
    # Length cap for text; the unit is not defined in this file.
    max_text_length: int = 512
    # Presumably the held-out fraction and seed for a train/test split --
    # confirm against the code that performs the split.
    test_size: float = 0.1
    random_state: int = 42
class ConfigManager:
    """Hold the model, training and data settings and sync them with a JSON file.

    The JSON file is an object with optional ``model``, ``training`` and
    ``data`` sections; each section is an object mapping field names to
    values.  Constructing a manager immediately tries to load
    ``config_path``; a missing file is not an error and simply leaves the
    defaults in place.

    Attributes:
        config_path: Path of the JSON configuration file.
        model_config: ``ModelConfig`` instance.
        training_config: ``TrainingConfig`` instance.
        data_config: ``DataConfig`` instance.
    """

    # JSON section name -> name of the attribute holding that section's dataclass.
    _SECTIONS = (
        ("model", "model_config"),
        ("training", "training_config"),
        ("data", "data_config"),
    )

    def __init__(self, config_path: str = "config.json"):
        """Create the default settings, then overlay values from ``config_path``."""
        self.config_path = config_path
        self.model_config = ModelConfig()
        self.training_config = TrainingConfig()
        self.data_config = DataConfig()
        self.load_config()

    def load_config(self):
        """Overlay settings from the JSON file, if it exists.

        Only keys that name a declared dataclass field are applied; unknown
        keys are ignored.  A section that is not a JSON object is skipped
        with a warning and does not prevent the other sections from loading.
        Read and parse failures are logged as warnings, never raised.
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                config_data = json.load(f)
        except FileNotFoundError:
            # No config file yet: keep the defaults silently.
            return
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logging.warning(f"Could not load config from {self.config_path}: {e}")
            return

        if not isinstance(config_data, dict):
            logging.warning(
                f"Could not load config from {self.config_path}: "
                "top-level JSON value must be an object"
            )
            return

        for section, attr_name in self._SECTIONS:
            if section not in config_data:
                continue
            values = config_data[section]
            if not isinstance(values, dict):
                logging.warning(
                    f"Ignoring section '{section}' in {self.config_path}: "
                    "expected a JSON object"
                )
                continue
            target = getattr(self, attr_name)
            # Restrict to declared fields so keys such as "__doc__" or
            # "__class__" cannot overwrite non-configuration attributes.
            known_fields = {f.name for f in fields(target)}
            for key, value in values.items():
                if key in known_fields:
                    setattr(target, key, value)

        logging.info(f"Configuration loaded from {self.config_path}")

    def save_config(self):
        """Write the current settings to ``config_path`` as indented JSON.

        The document is serialised in memory first, so a value that cannot
        be encoded does not truncate an existing file.  Failures are logged
        as errors, never raised.
        """
        try:
            config_data = {
                section: asdict(getattr(self, attr_name))
                for section, attr_name in self._SECTIONS
            }
            payload = json.dumps(config_data, indent=2)
        except (TypeError, ValueError) as e:
            logging.error(f"Could not save config to {self.config_path}: {e}")
            return

        try:
            with open(self.config_path, 'w', encoding='utf-8') as f:
                f.write(payload)
        except (OSError, ValueError) as e:
            logging.error(f"Could not save config to {self.config_path}: {e}")
            return

        logging.info(f"Configuration saved to {self.config_path}")
# Kinyarwanda language utilities
# Words dropped by extract_keywords_kinyarwanda before counting frequencies.
# Stored as a set for constant-time membership tests.
KINYARWANDA_STOPWORDS = {
    "ni", "na", "ku", "mu", "nk",
    "no", "cyangwa", "ariko", "naho", "none",
    "kandi", "rero", "ubwo", "uko", "ubu",
    "aha", "aho", "iyo", "ese",
    "nta", "nti", "nte", "nto", "ntu",
    "ntw", "aba", "ari", "hari",
    "kuri", "muri", "buri",
    "abantu", "umuntu", "ibintu", "ikintu",
}
# Kinyarwanda legal terminology
# Kinyarwanda term -> English gloss.  Multi-word terms use "_" in place of a
# space.  Entry order is unchanged; the group headings below only reflect the
# English glosses.
KINYARWANDA_LEGAL_TERMS = {
    # Offences
    "gusambanya": "sexual defilement",
    "kwiba": "theft",
    "gukoresha_imbaraga": "use of force/violence",
    "kwinjira": "enter/trespass",
    "kwica": "kill/murder",
    "gukubita": "assault/beat",
    "uburiganya": "fraud/deception",
    "ubuhemu": "embezzlement",
    # Penalties
    "igifungo": "imprisonment",
    "ihazabu": "fine",
    "igihano": "punishment",
    # Statute references
    "ingingo": "article",
    "itegeko": "law",
    # Other vocabulary
    "umwana": "child",
    "imyaka": "years",
    "amezi": "months",
    "burundu": "life (imprisonment)",
    "gahato": "force/violence",
    "imibonano_mpuzabitsina": "sexual intercourse",
    "inyamaswa": "animals",
    "rugo": "home/house",
}
# Patterns used by clean_kinyarwanda_text, compiled once at import time.
# A numeric token: digits with optional "," / "." group separators.
_CLEAN_NUM = r'[0-9]+(?:[.,][0-9]+)*'
# Everything except word characters, whitespace and basic punctuation.
_CLEAN_DISALLOWED_RE = re.compile(r'[^\w\s\-\.\,\;\:\!\?\'\"\u00C0-\u017F]')
# "FRW 1,000,000" or "FRW 1,000,000-2,000,000" (any letter case for FRW).
_CLEAN_AMOUNT_RE = re.compile(
    rf'\bFRW\s*{_CLEAN_NUM}(?:\s*-\s*{_CLEAN_NUM})?\b', re.IGNORECASE
)
# "10-15"
_CLEAN_RANGE_RE = re.compile(rf'\b{_CLEAN_NUM}-{_CLEAN_NUM}\b')
_CLEAN_NUMBER_RE = re.compile(rf'\b{_CLEAN_NUM}\b')
_CLEAN_WHITESPACE_RE = re.compile(r'\s+')


def clean_kinyarwanda_text(text: str) -> str:
    """Clean and normalize Kinyarwanda text.

    Removes characters other than word characters, whitespace and basic
    punctuation, replaces numeric content with placeholders and collapses
    whitespace.

    Placeholders are substituted from most to least specific -- ``AMOUNT``
    (``FRW`` followed by a number or number range), then ``RANGE``
    (``10-15``), then ``NUMBER`` -- because substituting bare numbers first
    would consume the digits the other two patterns need.

    Args:
        text: Raw text.  ``None``, NaN and other falsy values are accepted.

    Returns:
        The cleaned text with single spaces between tokens, or ``""`` for
        empty/missing input.
    """
    if not text or pd.isna(text):
        return ""

    cleaned = _CLEAN_DISALLOWED_RE.sub('', str(text))
    cleaned = _CLEAN_AMOUNT_RE.sub(' AMOUNT ', cleaned)
    cleaned = _CLEAN_RANGE_RE.sub(' RANGE ', cleaned)
    cleaned = _CLEAN_NUMBER_RE.sub(' NUMBER ', cleaned)
    # Collapse whitespace last: the substitutions above pad placeholders with
    # spaces and removing characters can leave adjacent spaces behind.
    return _CLEAN_WHITESPACE_RE.sub(' ', cleaned).strip()
# Placeholder tokens inserted by clean_kinyarwanda_text; they are artefacts of
# cleaning, not words from the text, so they must not be reported as keywords.
_KEYWORD_PLACEHOLDERS = frozenset({"NUMBER", "AMOUNT", "RANGE"})
# Punctuation that clean_kinyarwanda_text keeps and that can cling to a token.
_KEYWORD_EDGE_PUNCTUATION = ".,;:!?'\"-"


def extract_keywords_kinyarwanda(text: str, max_keywords: int = 10) -> list:
    """Extract the most frequent keywords from Kinyarwanda text.

    The text is cleaned with ``clean_kinyarwanda_text`` and split on
    whitespace.  Each token has leading/trailing punctuation stripped (so
    ``"kwiba,"`` and ``"kwiba"`` count as the same word) and is lower-cased.
    Cleaner placeholders, stopwords and tokens of two characters or fewer
    are discarded.

    Args:
        text: Raw text; falsy input yields an empty list.
        max_keywords: Maximum number of keywords to return.

    Returns:
        Keywords ordered from most to least frequent.
    """
    if not text:
        return []

    keywords = []
    for raw_token in clean_kinyarwanda_text(text).split():
        token = raw_token.strip(_KEYWORD_EDGE_PUNCTUATION)
        # Placeholders are compared before lower-casing so a genuine
        # lower-case word such as "range" is still kept.
        if token in _KEYWORD_PLACEHOLDERS:
            continue
        token = token.lower()
        if len(token) > 2 and token not in KINYARWANDA_STOPWORDS:
            keywords.append(token)

    return [word for word, _ in Counter(keywords).most_common(max_keywords)]
# Legal category mappings: category name -> English and Kinyarwanda keywords.
# Multi-word Kinyarwanda keywords are stored with "_" in place of the space.
# Dictionary order matters: categorize_case resolves score ties in favour of
# the category listed first.
LEGAL_CATEGORIES = {
    'sexual_offence': {
        'english': ['sexual', 'rape', 'assault', 'child', 'defilement'],
        'kinyarwanda': ['gukoresha', 'gusambanya', 'igitsina', 'umwana', 'imibonano']
    },
    'theft': {
        'english': ['theft', 'robbery', 'stealing', 'property'],
        'kinyarwanda': ['kwiba', 'gufata', 'umutungo', 'imbaraga']
    },
    'privacy': {
        'english': ['privacy', 'domicile', 'recording', 'entry'],
        'kinyarwanda': ['kwinjira', 'kumviriza', 'rugo', 'ubuzima_bwite']
    },
    'morality': {
        'english': ['adultery', 'bigamy', 'concubinage', 'marriage'],
        'kinyarwanda': ['ubusambanyi', 'ubushoreke', 'gushyingirwa', 'guta_urugo']
    },
    'violence': {
        'english': ['violence', 'murder', 'genocide', 'torture', 'assault'],
        'kinyarwanda': ['kwica', 'gukubita', 'jenoside', 'ihohotera', 'imbaraga']
    },
    'fraud': {
        'english': ['fraud', 'forgery', 'deception', 'embezzlement'],
        'kinyarwanda': ['uburiganya', 'kwigana', 'ubuhemu', 'kwibeshya']
    }
}

# Score contributed by each matching keyword; Kinyarwanda hits weigh more.
_ENGLISH_KEYWORD_WEIGHT = 2
_KINYARWANDA_KEYWORD_WEIGHT = 3


def _mentions_keyword(text: str, keyword: str) -> bool:
    """Return True if ``keyword`` occurs in ``text`` as a substring.

    A keyword written with underscores also matches its space-separated
    form, because running text spells such terms with spaces.
    """
    if keyword in text:
        return True
    return '_' in keyword and keyword.replace('_', ' ') in text


def categorize_case(description: str) -> str:
    """Categorize a case based on description keywords (English and Kinyarwanda).

    Each category scores 2 points per English keyword and 3 points per
    Kinyarwanda keyword found (case-insensitive substring match) in the
    description.  The highest-scoring category wins; ties go to the category
    that appears first in ``LEGAL_CATEGORIES``.

    Args:
        description: Case description.  ``None``, NaN and empty values are
            accepted; other non-string values are converted with ``str``.

    Returns:
        ``"unknown"`` for missing/empty input, ``"general"`` when no keyword
        matches, otherwise the winning category name.
    """
    if not description or pd.isna(description):
        return "unknown"

    description_lower = str(description).lower()

    best_category, best_score = "general", 0
    for category, terms in LEGAL_CATEGORIES.items():
        score = _ENGLISH_KEYWORD_WEIGHT * sum(
            _mentions_keyword(description_lower, keyword)
            for keyword in terms['english']
        )
        score += _KINYARWANDA_KEYWORD_WEIGHT * sum(
            _mentions_keyword(description_lower, keyword)
            for keyword in terms['kinyarwanda']
        )
        # Strict ">" keeps the first category on ties, like max() over the dict.
        if score > best_score:
            best_category, best_score = category, score

    return best_category
# Evaluation metrics
def calculate_similarity_score(text1: str, text2: str) -> float:
    """Return the word-set overlap (intersection over union) of two texts.

    Both texts are passed through ``clean_kinyarwanda_text``, lower-cased and
    split on whitespace; the score is the number of distinct shared tokens
    divided by the number of distinct tokens overall.  If either text is
    empty, or has no tokens left after cleaning, the score is 0.0.
    """
    if not (text1 and text2):
        return 0.0

    tokens_a = set(clean_kinyarwanda_text(text1).lower().split())
    tokens_b = set(clean_kinyarwanda_text(text2).lower().split())
    if not (tokens_a and tokens_b):
        return 0.0

    shared_count = len(tokens_a & tokens_b)
    total_count = len(tokens_a | tokens_b)
    if total_count > 0:
        return shared_count / total_count
    return 0.0
# Utility functions
def save_predictions(predictions: list, output_path: str):
    """Save predictions to a UTF-8 JSON file.

    The predictions are serialised in memory before the file is opened, so a
    value that cannot be encoded as JSON does not truncate or corrupt an
    existing file at ``output_path``.  Failures are logged as errors, never
    raised.

    Args:
        predictions: JSON-serialisable predictions.
        output_path: Destination file path; overwritten on success.
    """
    try:
        payload = json.dumps(predictions, indent=2, ensure_ascii=False)
    except (TypeError, ValueError) as e:
        logging.error(f"Could not save predictions: {e}")
        return

    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(payload)
    except (OSError, ValueError) as e:
        logging.error(f"Could not save predictions: {e}")
        return

    logging.info(f"Predictions saved to {output_path}")
def load_predictions(input_path: str) -> list:
    """Read predictions back from a UTF-8 JSON file.

    Returns the decoded JSON content.  Any failure (missing file, invalid
    JSON, ...) is logged as an error and an empty list is returned instead.
    """
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            predictions = json.load(f)
    except Exception as e:
        logging.error(f"Could not load predictions: {e}")
        return []
    else:
        logging.info(f"Predictions loaded from {input_path}")
        return predictions
# Patterns used by format_punishment; they run on lower-cased text.
# A number with an optional "-<number>" upper bound, spaces around "-" allowed.
_PUNISH_SPAN = r'(\d+)(?:\s*-\s*(\d+))?'
# English order ("10-15 years", "10 imyaka") or Kinyarwanda order, where the
# unit precedes the number ("imyaka 10-15").
_PUNISH_YEARS_RE = re.compile(
    rf'{_PUNISH_SPAN}\s*(?:years?|imyaka)|\bimyaka\s+{_PUNISH_SPAN}'
)
# "life" only at the start of a word, so "wildlife" does not count.
_PUNISH_LIFE_RE = re.compile(r'\blife|burundu')
# An amount starts and ends with a digit, so trailing punctuation is excluded.
_PUNISH_AMOUNT = r'(\d(?:[\d,.]*\d)?)'
_PUNISH_FINE_RE = re.compile(
    rf'frw\s*{_PUNISH_AMOUNT}(?:\s*-\s*{_PUNISH_AMOUNT})?'
)


def format_punishment(punishment: str) -> dict:
    """Parse and format punishment information.

    Args:
        punishment: Free-text punishment description in English and/or
            Kinyarwanda.  ``None``, NaN and empty values are accepted.

    Returns:
        A dict that always has the same keys:

        * ``type``: ``"imprisonment"``, ``"life_imprisonment"`` or
          ``"unknown"``.
        * ``imprisonment``: ``"<min>-<max> years"`` (``min == max`` when a
          single figure is given), ``"Life imprisonment"``, or ``None``.
        * ``fine``: ``"FRW <min>-<max>"`` or ``None``.
        * ``community_service``: ``"Yes"`` or ``None``.
        * ``details``: the lower-cased input (``""`` for missing input).
    """
    result = {
        "type": "unknown",
        "imprisonment": None,
        "fine": None,
        "community_service": None,
        "details": ""
    }
    if not punishment or pd.isna(punishment):
        return result

    punishment = str(punishment).lower()
    result["details"] = punishment

    # Imprisonment term in years.
    year_match = _PUNISH_YEARS_RE.search(punishment)
    if year_match:
        # Groups 1-2 belong to the number-first form, 3-4 to the unit-first form.
        if year_match.group(1) is not None:
            low, high = year_match.group(1), year_match.group(2)
        else:
            low, high = year_match.group(3), year_match.group(4)
        min_years = int(low)
        max_years = int(high) if high else min_years
        result["imprisonment"] = f"{min_years}-{max_years} years"
        result["type"] = "imprisonment"

    # Life imprisonment takes precedence over any year figure found above.
    if _PUNISH_LIFE_RE.search(punishment):
        result["imprisonment"] = "Life imprisonment"
        result["type"] = "life_imprisonment"

    # Fine amounts.
    fine_match = _PUNISH_FINE_RE.search(punishment)
    if fine_match:
        min_fine = fine_match.group(1)
        max_fine = fine_match.group(2) if fine_match.group(2) else min_fine
        result["fine"] = f"FRW {min_fine}-{max_fine}"

    # Community service.
    if 'community service' in punishment or 'inyungu rusange' in punishment:
        result["community_service"] = "Yes"

    return result
# Module-wide configuration object; constructing it reads config.json from
# the current working directory when that file exists.
config_manager = ConfigManager()


def _run_demo():
    """Print the output of the text utilities for two sample inputs."""
    sample_text = "Umuntu yakoreye umwana igikorwa gishingiye ku gitsina"
    print("Sample text:", sample_text)
    print("Cleaned:", clean_kinyarwanda_text(sample_text))
    print("Keywords:", extract_keywords_kinyarwanda(sample_text))
    print("Category:", categorize_case(sample_text))

    sample_punishment = "Igifungo cy'imyaka 10-15 + ihazabu FRW 1,000,000-2,000,000"
    print("Formatted punishment:", format_punishment(sample_punishment))


if __name__ == "__main__":
    # Manual smoke test of the utilities when the module is run as a script.
    _run_demo()