---
language:
- en
- sw
tags:
- multi-task-learning
- text-classification
- fraud-detection
- sentiment-analysis
- call-quality
- question-answering
- jenga-ai
- nlp-for-africa
- security
- attention-fusion
base_model: distilbert-base-uncased
license: apache-2.0
pipeline_tag: text-classification
datasets:
- custom
model-index:
- name: JengaAI-multi-task-nlp
  results:
  - task:
      type: text-classification
      name: Fraud Detection
    metrics:
    - type: f1
      value: 1
      name: F1
    - type: accuracy
      value: 1
      name: Accuracy
  - task:
      type: text-classification
      name: Sentiment Analysis
    metrics:
    - type: f1
      value: 0.167
      name: F1
    - type: accuracy
      value: 0.333
      name: Accuracy
  - task:
      type: text-classification
      name: Call Quality - Listening
    metrics:
    - type: f1
      value: 0.922
      name: F1
  - task:
      type: text-classification
      name: Call Quality - Resolution
    metrics:
    - type: f1
      value: 0.908
      name: F1
widget:
- text: >-
    Suspicious M-Pesa transaction detected from unknown account requesting
    urgent transfer
  example_title: Fraud Detection
- text: >-
    The customer service was excellent, my billing issue was resolved on the
    first call
  example_title: Positive Sentiment
- text: Hello, welcome to Safaricom customer care. How can I assist you today?
  example_title: Call Quality Scoring
library_name: transformers
---
# JengaAI Multi-Task NLP (3-Task Attention Fusion)
A **multi-task NLP model** built with the [JengaAI framework](https://github.com/Rogendo/JengaAI) that performs **fraud detection**, **sentiment analysis**, and **call quality scoring** simultaneously through a shared encoder with attention-based task fusion. Designed for Kenyan national security and telecommunications applications.
## Model Capabilities
This model handles **3 tasks** with **8 prediction heads** producing **22 total output dimensions** in a single forward pass:
| Task | Type | Heads | Outputs | Eval F1 |
|:-----|:-----|:------|:--------|:--------|
| **Fraud Detection** | Binary classification | 1 (fraud) | 2 classes: normal / fraud | **1.000** |
| **Sentiment Analysis** | 3-class classification | 1 (sentiment) | 3 classes: negative / neutral / positive | 0.167 |
| **Call Quality Scoring** | Multi-label quality assurance (QA) | 6 heads, 17 sub-metrics | Binary per sub-metric | **0.647 - 0.967** (per head) |
### Call Quality Sub-Metrics (17 Binary Outputs)
The call quality task evaluates customer service transcripts across 6 quality dimensions:
| Head | Sub-Metrics | F1 |
|:-----|:-----------|:---|
| **Opening** | greeting | 0.967 |
| **Listening** | acknowledgment, empathy, clarification, active_listening, patience | 0.922 |
| **Proactiveness** | initiative, follow_up, suggestions | 0.802 |
| **Resolution** | identified_issue, provided_solution, confirmed_resolution, set_expectations, offered_alternatives | 0.908 |
| **Hold** | asked_permission, explained_reason | 0.647 |
| **Closing** | proper_farewell | 0.881 |
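Each call-quality head emits one sigmoid score per sub-metric, so the 17 outputs can be regrouped by head name into the `{sub_metric: bool}` dicts shown in the usage example further down. A minimal sketch, assuming the heads' raw logits arrive as a dict keyed by head name, that sub-metrics are ordered as in the table above, and that a 0.5 threshold is used; `decode_qa_heads` is a hypothetical helper, not part of the JengaAI API:

```python
import math

# Head -> sub-metric layout from the table above (ordering within each head is assumed)
QA_HEADS: dict[str, list[str]] = {
    "opening": ["greeting"],
    "listening": ["acknowledgment", "empathy", "clarification", "active_listening", "patience"],
    "proactiveness": ["initiative", "follow_up", "suggestions"],
    "resolution": ["identified_issue", "provided_solution", "confirmed_resolution",
                   "set_expectations", "offered_alternatives"],
    "hold": ["asked_permission", "explained_reason"],
    "closing": ["proper_farewell"],
}


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def decode_qa_heads(head_logits: dict[str, list[float]],
                    threshold: float = 0.5) -> dict[str, dict[str, bool]]:
    """Turn per-head raw logits into {head: {sub_metric: bool}} using a sigmoid threshold."""
    decoded: dict[str, dict[str, bool]] = {}
    for head, metrics in QA_HEADS.items():
        logits = head_logits[head]
        if len(logits) != len(metrics):
            raise ValueError(f"{head}: expected {len(metrics)} logits, got {len(logits)}")
        decoded[head] = {m: sigmoid(z) >= threshold for m, z in zip(metrics, logits)}
    return decoded


# Hypothetical logits: a confident greeting, every other sub-metric below threshold
example = {head: [-1.0] * len(metrics) for head, metrics in QA_HEADS.items()}
example["opening"] = [2.0]
print(decode_qa_heads(example)["opening"])  # {'greeting': True}
```

Thresholding each sub-metric independently is what makes this multi-label: any combination of the 17 flags can be true at once, unlike the softmax fraud and sentiment heads.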
## Architecture
```
Input Text
|
v
[DistilBERT Encoder] ---- 6 layers, 768 hidden, 12 attention heads
|
v
[Attention Fusion] ------- task-conditioned attention with residual connections
|
+-- [Task 0: Fraud Head] ----------- Linear(768, 2) --> softmax
+-- [Task 1: Sentiment Head] ------- Linear(768, 3) --> softmax
+-- [Task 2: QA Scoring 6 Heads] --- 6x Linear(768, 1..5) --> sigmoid
```
**Key design choices:**
- **Shared encoder**: All 3 tasks share a single DistilBERT encoder, enabling knowledge transfer between fraud patterns, sentiment signals, and call quality indicators
- **Attention fusion**: A learned attention mechanism modulates the shared representation per task, allowing each task to attend to different parts of the encoder output while still benefiting from shared features
- **Gated residual connections**: Fusion output is combined with the original representation through a learned gate (gate_init_value=0.5), ensuring stable training and allowing each task to fall back on the base representation
- **Multi-head QA**: Call quality uses 6 independent classification heads with different output sizes (1-5 binary outputs each), weighted by importance during training (resolution: 2.0x, listening: 1.5x, hold: 0.5x)
## Usage
### With JengaAI Framework (Recommended)
```bash
pip install torch transformers pydantic pyyaml huggingface_hub
# jenga_ai is not among these packages; get the framework from https://github.com/Rogendo/JengaAI
```
```python
from huggingface_hub import snapshot_download
from jenga_ai.inference import InferencePipeline
# Download model
model_path = snapshot_download(
"Rogendo/JengaAI-multi-task-nlp",
ignore_patterns=["checkpoints/*", "logs/*"],
)
# Load pipeline
pipeline = InferencePipeline.from_checkpoint(
model_dir=model_path,
config_path=f"{model_path}/experiment_config.yaml",
device="auto",
)
# Run all 3 tasks at once
result = pipeline.predict("Suspicious M-Pesa transaction from unknown account")
print(result.to_json())
# Or run a single task
fraud_result = pipeline.predict(
"WARNING: Your Safaricom account has been compromised. Send 5000 KES to unlock.",
task_name="fraud_detection",
)
fraud = fraud_result.task_results["fraud_detection"].heads["fraud"]
print(f"Fraud: {fraud.prediction} (confidence: {fraud.confidence:.1%})")
# Fraud: 1 (confidence: 96.9%)
```
### Batch Inference
```python
texts = [
"Suspicious M-Pesa notification asking me to send money.",
"Normal airtime top-up of 100 KES via M-Pesa.",
"WARNING: Your account has been compromised.",
]
results = pipeline.predict_batch(texts, task_name="fraud_detection", batch_size=32)
for text, result in zip(texts, results):
fraud = result.task_results["fraud_detection"].heads["fraud"]
label = "FRAUD" if fraud.prediction == 1 else "LEGIT"
print(f"[{label} {fraud.confidence:.1%}] {text}")
```
### CLI
```bash
# Single text
python -m jenga_ai predict \
--config experiment_config.yaml \
--model-dir ./model \
--text "Suspicious M-Pesa transaction from unknown account" \
--format report
# Batch from file
python -m jenga_ai predict \
--config experiment_config.yaml \
--model-dir ./model \
--input-file transcripts.jsonl \
--output predictions.json \
--batch-size 16
```
### Call Quality Scoring Example
```python
result = pipeline.predict(
"Hello, welcome to Safaricom customer care. I understand you're having "
"a billing issue. Let me look into that for you right away. I've found "
"the discrepancy and corrected your balance. Is there anything else?",
task_name="call_quality",
)
for head_name, head in result.task_results["call_quality"].heads.items():
print(f"{head_name:16s} {head.prediction} (conf: {head.confidence:.2f})")
```
Example output (abridged; `...` marks omitted sub-metrics):
```
opening {'greeting': True} (conf: 0.82)
listening {'acknowledgment': True, 'empathy': True, ...} (conf: 0.75)
proactiveness {'initiative': True, 'follow_up': True, 'suggestions': False} (conf: 0.58)
resolution {'identified_issue': True, 'provided_solution': True, ...} (conf: 0.69)
hold {'asked_permission': False, 'explained_reason': False} (conf: 0.02)
closing {'proper_farewell': True} (conf: 0.52)
```
### Low-Level Usage (Without JengaAI Framework)
If you only need the raw model weights and want to integrate them into your own pipeline:
```python
import json

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Paths assume the repository was downloaded to ./model (e.g. via snapshot_download)
MODEL_DIR = "./model"

# Load components
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encoder_config = AutoConfig.from_pretrained(f"{MODEL_DIR}/encoder_config")
with open(f"{MODEL_DIR}/metadata.json") as f:
    metadata = json.load(f)  # task/head metadata (not used below)

# Load the full multi-task state dict
state_dict = torch.load(f"{MODEL_DIR}/model.pt", map_location="cpu", weights_only=True)

# Extract encoder weights, stripping only the leading "encoder." prefix
prefix = "encoder."
encoder_state = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
encoder = AutoModel.from_config(encoder_config)
encoder.load_state_dict(encoder_state)
encoder.eval()

# Run the encoder
inputs = tokenizer("Suspicious transaction", return_tensors="pt", padding="max_length",
                   truncation=True, max_length=256)
with torch.no_grad():
    outputs = encoder(**inputs)
cls_embedding = outputs.last_hidden_state[:, 0]  # [1, 768]

# Apply the fraud head (task 0, head "fraud") directly to the [CLS] embedding.
# NOTE: this skips the attention-fusion layer, so the probability is only an
# approximation and may differ from the InferencePipeline output.
fraud_weight = state_dict["tasks.0.heads.fraud.1.weight"]  # [2, 768]
fraud_bias = state_dict["tasks.0.heads.fraud.1.bias"]  # [2]
logits = cls_embedding @ fraud_weight.T + fraud_bias
probs = torch.softmax(logits, dim=-1)
print(f"Fraud probability: {probs[0, 1].item():.4f}")
```
## Intended Use
### Primary Use Cases
- **M-Pesa Fraud Detection**: Classify M-Pesa transaction descriptions as fraudulent or legitimate. Designed for Safaricom and Kenyan mobile money contexts.
- **Customer Sentiment Monitoring**: Analyze customer feedback and communications for sentiment polarity (negative / neutral / positive).
- **Call Center Quality Assurance**: Score customer service call transcripts across 17 quality sub-metrics in 6 categories, reducing the manual QA audit workload.
- **Multi-Signal Analysis**: Run all 3 tasks simultaneously on the same text to get a comprehensive analysis (is this a fraud attempt? what's the sentiment? how good was the agent's response?).
### Intended Users
- Kenyan telecommunications companies (Safaricom, Airtel Kenya)
- Financial institutions monitoring mobile money transactions
- Call center operations teams performing quality audits
- Security analysts processing incident reports
- NLP researchers working on African language and context models
### Downstream Use
The model can be integrated into:
- Real-time fraud alerting systems
- Call center dashboards with automated QA scoring
- Customer feedback analysis pipelines
- Security operations center (SOC) threat triage workflows
- Mobile money transaction monitoring platforms
## Out-of-Scope Use
- **Not for automated decision-making without human oversight.** This model should support human analysts, not replace them. High-stakes fraud decisions require human review.
- **Not for non-Kenyan contexts without retraining.** Entity names, transaction patterns, and call center norms are Kenyan-specific.
- **Not for languages other than English.** While some Swahili terms and Kenyan entity names (e.g. M-Pesa, Safaricom, KRA) appear in the training data, the model is primarily English.
- **Not for legal evidence.** Model outputs are analytical signals, not forensic evidence.
- **Not for surveillance of individuals.** The model analyzes text content, not identity.
## Bias, Risks, and Limitations
### Known Biases
- **Very small task datasets**: Fraud detection was trained on only 20 samples (16 train / 4 eval). The perfect 1.0 F1 on the eval set most likely reflects the tiny eval set and overfitting rather than reliable fraud detection. Real-world fraud patterns are far more diverse.
- **Sentiment data**: Only 15 samples, with accuracy stuck at 33.3% (random baseline for 3 classes). The sentiment head needs significantly more training data to be production-useful.
- **Call quality data**: 4,996 synthetic transcripts. While metrics are strong (0.65-0.97 F1), the synthetic nature means real-world transcripts with noise, code-switching (Swahili-English), and non-standard grammar may perform differently.
- **Geographic bias**: All training data reflects Kenyan contexts. The model may not generalize to other East African countries without adaptation.
### Risks
- **False positives in fraud detection**: Legitimate transactions flagged as fraud can block real users. Always use this model with human review for enforcement actions.
- **False negatives in fraud detection**: Sophisticated fraud patterns not in the training data will be missed. This model is one signal among many, not a standalone detector.
- **Over-reliance on QA scores**: Call quality scores should augment, not replace, human QA reviewers. Edge cases (cultural nuances, sarcasm, escalation scenarios) may be scored incorrectly.
### Recommendations
- Use fraud detection as a **triage signal** (flag for review), not an automatic block
- Retrain with production-scale data before deploying to production
- Monitor prediction confidence — route low-confidence predictions to human review using the built-in HITL routing (`enable_hitl=True`)
- Enable PII redaction (`enable_pii=True`) when processing real customer data
- Enable audit logging (`enable_audit=True`) for compliance and accountability
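The confidence-based routing recommended above can be illustrated with normalized entropy, the kind of entropy-based uncertainty the HITL feature is described as using. This is a sketch only: `normalized_entropy`, `route`, and the 0.5 cut-off are hypothetical, and JengaAI's actual HITL thresholds and API may differ.

```python
import math


def normalized_entropy(probs: list[float]) -> float:
    """Shannon entropy of a probability vector, scaled to [0, 1] by log(K)."""
    k = len(probs)
    if k < 2:
        return 0.0
    h = -sum(p * math.log(p) for p in probs if p > 0)
    return h / math.log(k)


def route(probs: list[float], max_uncertainty: float = 0.5) -> str:
    """Send high-uncertainty predictions to a human reviewer instead of acting on them."""
    return "human_review" if normalized_entropy(probs) > max_uncertainty else "auto_triage"


print(route([0.969, 0.031]))  # auto_triage  (normalized entropy ~0.20)
print(route([0.55, 0.45]))    # human_review (normalized entropy ~0.99)
```

Even the `auto_triage` path should only flag, not block, in line with the recommendations above; the cut-off is a policy knob to tune on held-out data.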
## Training Details
### Training Data
| Dataset | Task | Samples | Source |
|:--------|:-----|:--------|:-------|
| `sample_classification.jsonl` | Fraud Detection | 20 | Synthetic M-Pesa transaction descriptions |
| `sample_sentiment.jsonl` | Sentiment Analysis | 15 | Synthetic customer feedback |
| `synthetic_qa_metrics_data_v01x.json` | Call Quality | 4,996 | Synthetic call center transcripts with 17 binary QA labels |
**Train/eval split**: 80/20 random split (seed=42)
All datasets are synthetic, generated to reflect linguistic patterns in Kenyan telecommunications and financial services contexts. They contain English text with occasional Swahili terms and Kenyan-specific entities (M-Pesa, Safaricom, KRA, Kenyan phone numbers).
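A seeded 80/20 split explains the tiny eval sets discussed later (4 fraud samples, 3 sentiment samples). The helper below is hypothetical, since JengaAI's split code is not shown; its rounding rule in particular is an assumption, so the call-quality counts may differ by one from the real split.

```python
import random
from collections.abc import Sequence


def train_eval_split(samples: Sequence, eval_fraction: float = 0.2, seed: int = 42):
    """Shuffle deterministically and hold out eval_fraction of the samples."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    n_eval = round(len(shuffled) * eval_fraction)
    return shuffled[n_eval:], shuffled[:n_eval]


for name, n in [("fraud", 20), ("sentiment", 15), ("call_quality", 4996)]:
    train, eval_ = train_eval_split(range(n))
    print(f"{name:13s} train={len(train):5d} eval={len(eval_):4d}")
```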
### Training Procedure
#### Preprocessing
- Tokenizer: `distilbert-base-uncased` WordPiece tokenizer
- Max sequence length: 256 tokens
- Padding: `max_length` (padded to 256)
- Truncation: enabled
#### Architecture
- **Encoder**: DistilBERT (6 layers, 768 hidden, 12 heads) — 66.4M parameters
- **Fusion**: Attention fusion with residual connections — 1.2M parameters
- **Task heads**: 8 linear heads across 3 tasks — 17K parameters
- **Total**: 67.6M parameters (258MB on disk)
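The head and size figures can be cross-checked with simple arithmetic. This sketch assumes each head contributes exactly one Linear(768, k) layer with bias and no other trainable parameters, and that the 258 MB on-disk figure is in MiB with weights stored as FP32 (4 bytes each).

```python
HIDDEN = 768
# Output units per linear head, from the architecture diagram and sub-metric table
HEAD_OUTPUTS = {"fraud": 2, "sentiment": 3, "opening": 1, "listening": 5,
                "proactiveness": 3, "resolution": 5, "hold": 2, "closing": 1}


def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a single fully connected layer."""
    return in_features * out_features + out_features


head_params = sum(linear_params(HIDDEN, k) for k in HEAD_OUTPUTS.values())
total_outputs = sum(HEAD_OUTPUTS.values())
total_params = 66.4e6 + 1.2e6 + head_params  # encoder + fusion + heads
fp32_mib = 67.6e6 * 4 / 2**20                # 4 bytes per FP32 parameter

print(total_outputs)                   # 22
print(head_params)                     # 16918 (~17K)
print(f"{total_params / 1e6:.1f}M")    # 67.6M
print(round(fp32_mib, 1))              # 257.9
```

Because every head reads the same 768-dim vector, head cost is 769 parameters per output unit, so 22 outputs give 16,918 parameters, matching the ~17K above.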
#### Training Hyperparameters
| Parameter | Value |
|:----------|:------|
| Learning rate | 2e-5 |
| Batch size | 16 |
| Epochs | 12 (best checkpoint at epoch 3) |
| Weight decay | 0.01 |
| Warmup steps | 20 |
| Max gradient norm | 1.0 |
| Optimizer | AdamW |
| Precision | FP32 |
| Task sampling | Proportional (temperature=2.0) |
| Early stopping patience | 5 epochs |
| Best model metric | eval_loss |
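"Proportional (temperature=2.0)" most likely refers to temperature-scaled task sampling, where task i is drawn with probability proportional to n_i^(1/T). That formula is an assumption here, since the card does not spell out JengaAI's sampler, and the sentiment task name below is hypothetical.

```python
def temperature_sampling_probs(sizes: dict[str, int], temperature: float = 2.0) -> dict[str, float]:
    """p_i proportional to n_i ** (1 / T); T=1 is plain proportional sampling."""
    scaled = {task: n ** (1.0 / temperature) for task, n in sizes.items()}
    total = sum(scaled.values())
    return {task: s / total for task, s in scaled.items()}


# Dataset sizes from the Training Data table
SIZES = {"fraud_detection": 20, "sentiment": 15, "call_quality": 4996}

for t in (1.0, 2.0):
    probs = temperature_sampling_probs(SIZES, t)
    print(t, {task: round(p, 3) for task, p in probs.items()})
```

Under this scheme, T = 2 raises the fraud task's share from under 1% to roughly 6% of draws, so the small tasks are seen more often without being oversampled to parity with call quality.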
#### Task Loss Weights
| Head | Weight | Rationale |
|:-----|:-------|:----------|
| fraud | 1.0 | Standard |
| sentiment | 1.0 | Standard |
| opening | 1.0 | Standard |
| listening | 1.5 | Important quality dimension |
| proactiveness | 1.0 | Standard |
| resolution | 2.0 | Most critical quality dimension |
| hold | 0.5 | Less frequent in transcripts |
| closing | 1.0 | Standard |
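These weights scale each head's loss before the losses are combined. A minimal pure-Python sketch, assuming the weighted per-head losses are summed (the card does not say whether JengaAI sums or averages) and using cross-entropy for the fraud and sentiment heads and mean binary cross-entropy with logits for the call-quality heads, as listed under Technical Specifications:

```python
import math

# Per-head weights from the table above
LOSS_WEIGHTS = {"fraud": 1.0, "sentiment": 1.0, "opening": 1.0, "listening": 1.5,
                "proactiveness": 1.0, "resolution": 2.0, "hold": 0.5, "closing": 1.0}


def cross_entropy(logits: list[float], target: int) -> float:
    """Softmax cross-entropy for one example (fraud and sentiment heads)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]


def bce_with_logits(logits: list[float], targets: list[float]) -> float:
    """Numerically stable mean binary cross-entropy over one head's sub-metrics."""
    losses = [max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z))) for z, y in zip(logits, targets)]
    return sum(losses) / len(losses)


def weighted_loss(head_losses: dict[str, float]) -> float:
    """Sum each head's loss scaled by its weight (summing is an assumption)."""
    return sum(LOSS_WEIGHTS[head] * loss for head, loss in head_losses.items())


# Hypothetical per-head losses for illustration
losses = {
    "fraud": cross_entropy([2.0, -1.0], target=0),
    "resolution": bce_with_logits([1.5, 0.3, -0.2, 0.8, -1.0], [1, 1, 0, 1, 0]),
    "hold": bce_with_logits([-2.0, -1.5], [0, 0]),
}
print(round(weighted_loss(losses), 4))
```

If batches are drawn per task (see task sampling above), only the sampled task's heads would contribute to a given step's loss.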
#### Training Loss Progression
| Epoch | Train Loss | Eval Loss | Status |
|:------|:-----------|:----------|:-------|
| 3 | 1.878 | **1.948** | Best checkpoint |
| 7 | 1.471 | 2.057 | Overfitting begins |
| 8 | 1.403 | 2.068 | Continued overfitting |
The best checkpoint was selected at epoch 3 based on eval_loss. Training continued past epoch 3, but eval loss rose while train loss kept falling, indicating overfitting; this is expected given the small fraud and sentiment datasets.
### Speeds, Sizes, Times
| Metric | Value |
|:-------|:------|
| Model size (disk) | 258 MB |
| Parameters | 67.6M |
| Inference latency (single task, CPU) | ~590 ms |
| Inference latency (all 3 tasks, CPU) | ~1,960 ms |
| Batch inference latency (32 texts, single task, CPU) | ~647 ms per sample |
| Training time | ~5 minutes (CPU, 12 epochs) |
## Evaluation
### Metrics
All metrics are computed on the 20% held-out eval split.
**Fraud Detection** (binary classification):
| Metric | Value |
|:-------|:------|
| Accuracy | 1.000 |
| Precision | 1.000 |
| Recall | 1.000 |
| F1 | 1.000 |
**Sentiment Analysis** (3-class classification):
| Metric | Value |
|:-------|:------|
| Accuracy | 0.333 |
| Precision | 0.111 |
| Recall | 0.333 |
| F1 | 0.167 |
**Call Quality** (multi-label binary per head):
| Head | Precision | Recall | F1 |
|:-----|:----------|:-------|:---|
| Opening | 0.967 | 0.967 | **0.967** |
| Listening | 0.893 | 0.953 | **0.922** |
| Proactiveness | 0.746 | 0.868 | **0.802** |
| Resolution | 0.918 | 0.898 | **0.908** |
| Hold | 0.856 | 0.519 | **0.647** |
| Closing | 0.881 | 0.881 | **0.881** |
### Results Summary
- **Fraud detection** achieves perfect metrics on the eval set, but this is a very small eval set (4 samples). Production deployment requires evaluation on a larger, more diverse dataset.
- **Sentiment analysis** performs at random baseline (33.3% accuracy for 3 classes), indicating the 15-sample dataset is insufficient. This head needs retraining with production data.
- **Call quality** shows strong performance across most heads (0.80-0.97 F1), with the "hold" category being the weakest (0.647 F1) due to fewer hold-related examples in the training data.
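The sentiment numbers match exactly what macro-averaged metrics produce when the head predicts the same class for all 3 eval samples (20% of 15), assuming one sample per class. The card does not show the eval predictions, so this is a plausible explanation rather than a confirmed one; `macro_metrics` is a hypothetical helper.

```python
def macro_metrics(y_true: list[int], y_pred: list[int], labels: list[int]) -> dict[str, float]:
    """Accuracy plus macro-averaged precision, recall, and F1 (0 when undefined)."""
    precisions, recalls, f1s = [], [], []
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    n = len(labels)
    return {
        "accuracy": sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true),
        "precision": sum(precisions) / n,
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }


# Hypothetical eval set: one negative (0), one neutral (1), one positive (2) sample,
# with the head predicting "positive" for everything
print({k: round(v, 3) for k, v in macro_metrics([0, 1, 2], [2, 2, 2], [0, 1, 2]).items()})
# {'accuracy': 0.333, 'precision': 0.111, 'recall': 0.333, 'f1': 0.167}
```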
## Model Examination
### Attention Fusion
The attention fusion mechanism learns task-specific attention patterns over the shared encoder output. This is intended to allow:
- The fraud head to attend to transaction-related tokens (amounts, account references)
- The sentiment head to attend to opinion-bearing words
- The QA heads to attend to conversational flow patterns
The fusion uses a gated residual connection (initialized at 0.5), meaning each task's representation is a learned blend of the task-specific attended output and the original encoder output.
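A toy, pure-Python sketch of the idea: a per-task query vector attends over the encoder's token vectors, and a gate blends the attended vector with the original [CLS] representation. It omits the learned query/key/value projections a real layer would have, and treating the gate as a plain scalar blend weight is an assumption; `gated_attention_fusion` is hypothetical, not the JengaAI module.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gated_attention_fusion(tokens: list[list[float]], task_query: list[float],
                           gate: float = 0.5) -> tuple[list[float], list[float]]:
    """Attend over token vectors with a per-task query, then blend with tokens[0] ([CLS])."""
    d = len(task_query)
    # Scaled dot-product score between the task query and each token
    scores = [sum(q * t for q, t in zip(task_query, tok)) / math.sqrt(d) for tok in tokens]
    weights = softmax(scores)
    attended = [sum(w * tok[i] for w, tok in zip(weights, tokens)) for i in range(d)]
    cls = tokens[0]
    # Gated residual: blend of task-attended output and original representation
    fused = [gate * a + (1.0 - gate) * c for a, c in zip(attended, cls)]
    return fused, weights


fused, weights = gated_attention_fusion([[1.0, 0.0], [0.0, 1.0]], task_query=[0.0, 0.0])
print(weights, fused)  # [0.5, 0.5] [0.75, 0.25]
```

With the gate at its initial 0.5, each task starts halfway between its own attended view and the shared representation; a gate of 0 recovers the plain [CLS] vector, which is the "fall back on the base representation" behaviour described above.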
### Security Features
When used with the JengaAI inference framework, the model supports:
- **PII Redaction**: Masks Kenyan-specific PII (phone numbers, national IDs, KRA PINs, M-Pesa transaction IDs) before inference
- **Explainability**: Token-level importance scores via attention analysis or gradient methods
- **Human-in-the-Loop**: Automatic routing of low-confidence predictions to human reviewers using entropy-based uncertainty estimation
- **Audit Trail**: Tamper-evident logging of every inference call with SHA-256 hash chains
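A hash chain makes a log tamper-evident by having each entry's SHA-256 digest cover the previous entry's digest. The card only states that hash chains are used; the record fields, JSON serialization, and all-zero genesis value below are hypothetical choices for illustration.

```python
import hashlib
import json

GENESIS_HASH = "0" * 64  # hypothetical starting value for an empty log


def _entry_hash(prev_hash: str, record: dict) -> str:
    payload = json.dumps(record, sort_keys=True)
    return hashlib.sha256((prev_hash + payload).encode("utf-8")).hexdigest()


def append_record(chain: list[dict], record: dict) -> None:
    """Append an inference record whose hash covers the previous entry's hash."""
    prev_hash = chain[-1]["hash"] if chain else GENESIS_HASH
    chain.append({"record": record, "prev_hash": prev_hash,
                  "hash": _entry_hash(prev_hash, record)})


def verify_chain(chain: list[dict]) -> bool:
    """Recompute every hash; an edited record no longer matches its stored hash."""
    prev_hash = GENESIS_HASH
    for entry in chain:
        if entry["prev_hash"] != prev_hash or entry["hash"] != _entry_hash(prev_hash, entry["record"]):
            return False
        prev_hash = entry["hash"]
    return True


# Hypothetical inference records
log: list[dict] = []
append_record(log, {"task": "fraud_detection", "prediction": 1, "confidence": 0.969})
append_record(log, {"task": "fraud_detection", "prediction": 0, "confidence": 0.91})
print(verify_chain(log))  # True
log[0]["record"]["prediction"] = 0  # tamper with the first entry
print(verify_chain(log))  # False
```

Recomputing the tampered entry's hash does not help an attacker, because the next entry's stored `prev_hash` would then no longer match.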
## Technical Specifications
### Model Architecture and Objective
- **Architecture**: DistilBERT encoder + attention fusion + multi-task heads
- **Encoder**: 6 transformer layers, 768 hidden size, 12 attention heads, 30,522 vocab
- **Fusion**: Single-head attention with residual gating
- **Objectives**: CrossEntropy (fraud, sentiment) + BCEWithLogits (call quality)
### Compute Infrastructure
#### Hardware
- Training: CPU (Intel/AMD, standard workstation)
- Inference: CPU or CUDA GPU
#### Software
- PyTorch 2.x
- Transformers 5.x
- JengaAI Framework V2
- Python 3.11+
## Environmental Impact
- **Hardware Type**: CPU (standard workstation)
- **Training Time**: ~5 minutes
- **Carbon Emitted**: Negligible (short training run on CPU)
## Citation
```bibtex
@software{jengaai2026,
title = {JengaAI: Low-Code Multi-Task NLP for African Security Applications},
author = {Rogendo},
year = {2026},
url = {https://huggingface.co/Rogendo/JengaAI-multi-task-nlp},
}
```
## Model Card Authors
Rogendo
## Model Card Contact
For questions, issues, or contributions: [GitHub Issues](https://github.com/Rogendo/JengaAI/issues)