Tomas
committed on
Add initial project setup with model configuration, requirements, and upload script
- Heaven1-guardian.png +0 -0
- README.md +109 -3
- check_torch.py +56 -0
- config.yaml +46 -0
- create_dataset.py +251 -0
- data/heaven_dataset.jsonl +0 -0
- finetune_heaven.py +331 -0
- fix_nms_error.py +144 -0
- model_card.md +156 -0
- requirements.txt +18 -0
- run_heaven.py +112 -0
- upload_to_hub.py +94 -0
Heaven1-guardian.png
ADDED
README.md
CHANGED
@@ -1,3 +1,109 @@

# Heaven1-base: Guardian

## Overview

Heaven1-base (codename: "Guardian") is a specialized AI model fine-tuned from Llama 3.2 to detect predatory behavior in text messages. Designed as a protective tool, Guardian analyzes conversations to identify potentially harmful interactions, making digital spaces safer for vulnerable individuals.

The model has been trained to recognize various tactics commonly employed by predators, including:

- Grooming language and manipulation
- Attempts to isolate victims from support networks
- Requests for personal information or images
- Attempts to move conversations to more private platforms
- Emotional manipulation tactics
- Inappropriate sexual content

## Technical Details

- **Base Model**: Meta-Llama-3.2-8B-Instruct
- **Training Method**: Parameter-Efficient Fine-Tuning (PEFT) using QLoRA
- **Training Dataset**: Carefully crafted synthetic dataset representing various predatory conversation patterns
- **Task**: Text message analysis and predatory behavior detection with detailed explanations

## Usage

### Input Format

The model expects input in the following format:

```
<|system|>
You are Heaven, an AI designed to detect predatory behavior in text messages. Analyze the following message and determine if it contains predatory behavior. Provide a detailed explanation for your assessment.
<|user|>
[TEXT MESSAGE TO ANALYZE]
<|assistant|>
```

### Output Format

The model will respond with a detection result and detailed explanation:

```
PREDATORY BEHAVIOR DETECTED. This message contains multiple warning signs: (1) [Warning Sign 1], (2) [Warning Sign 2], etc. These are common tactics used by predators to manipulate potential victims.

OR

NO PREDATORY BEHAVIOR DETECTED. This message contains normal friendly communication. [Additional context about the message]. There are no attempts at manipulation, isolation, inappropriate requests, or other warning signs of predatory behavior.
```
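Because the response always begins with one of the two fixed verdict strings above, downstream code can reduce it to a boolean label. A minimal sketch of such a helper (illustrative, not part of this repository):

```python
def parse_verdict(response: str) -> bool:
    """Map the model's free-text assessment to a boolean flag.

    Returns True when the reply opens with the predatory verdict,
    False for the benign verdict; raises on anything unexpected so
    malformed generations surface instead of being silently dropped.
    """
    text = response.strip().upper()
    if text.startswith("PREDATORY BEHAVIOR DETECTED"):
        return True
    if text.startswith("NO PREDATORY BEHAVIOR DETECTED"):
        return False
    raise ValueError(f"Unrecognized verdict: {response[:80]!r}")
```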
### Example Usage with Transformers

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model_path = "safecircleia/heaven1-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# Message to analyze
message_to_analyze = "Hey, I know we just met but I feel like we have a special connection. Don't tell your parents about our chats, they wouldn't understand. Can you send me a picture of yourself?"

# Format the prompt
prompt = f"""<|system|>
You are Heaven, an AI designed to detect predatory behavior in text messages. Analyze the following message and determine if it contains predatory behavior. Provide a detailed explanation for your assessment.
<|user|>
{message_to_analyze}
<|assistant|>
"""

# Generate response
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(inputs["input_ids"], max_new_tokens=512)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(response)
```
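On GPUs with limited memory, the checkpoint can likely also be loaded with 4-bit quantization through `bitsandbytes` (already listed in `requirements.txt`), mirroring the NF4 settings used for QLoRA training. Treat this as an optional sketch rather than the supported path; whether it works depends on your CUDA stack:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_path = "safecircleia/heaven1-base"

# NF4 4-bit quantization, matching the QLoRA settings in finetune_heaven.py
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",  # place layers on the available GPU(s)
)
```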
## Ethical Considerations

- This model is designed as a protective tool to help identify potentially harmful communication patterns.
- False positives and false negatives are possible; human review should be employed for critical applications.
- The model should be used as part of a broader safety framework, not as the sole decision-maker.
- Privacy and consent are essential when analyzing communications.

## Limitations

- The model detects patterns based on its training data and may miss novel predatory tactics.
- Cultural and contextual nuances may impact detection accuracy.
- The model is not a substitute for human judgment in safeguarding vulnerable individuals.

## Citation

If you use Heaven1-base Guardian in your research or applications, please cite:

```
@misc{heaven1-base-2025,
  author = {SafeCircleIA},
  title = {Heaven1-base: Guardian - Predatory Behavior Detection Model},
  year = {2024},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/safecircleia/heaven1-base-guardian}}
}
```

## Contact

For questions, feedback, or concerns about the Heaven1-base Guardian model, please contact SafeCircleIA through Hugging Face or via contact@safecircle.tech.
check_torch.py
ADDED
@@ -0,0 +1,56 @@

import torch
import torchvision
import sys
import importlib

def check_installations():
    """Check PyTorch and torchvision installations."""
    print("Python version:", sys.version)
    print("PyTorch version:", torch.__version__)
    print("torchvision version:", torchvision.__version__)
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("CUDA version:", torch.version.cuda)
        print("cuDNN version:", torch.backends.cudnn.version())

    # Check if PyTorch and torchvision versions are compatible.
    # Note: torch and torchvision use different version schemes (e.g. torch 2.1
    # pairs with torchvision 0.16), so this strict major/minor comparison is
    # only a rough heuristic and will also warn on valid pairings.
    torch_version = torch.__version__.split('.')
    torchvision_version = torchvision.__version__.split('.')

    if torch_version[0] != torchvision_version[0] or torch_version[1] != torchvision_version[1]:
        print("WARNING: PyTorch and torchvision versions might be incompatible!")
        print("It's recommended to have matching major and minor version numbers.")

    # Check for NMS operator
    try:
        print("\nAttempting to import torchvision.ops...")
        from torchvision.ops import nms
        print("Successfully imported NMS operator.")

        # Create dummy data to test NMS
        boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]], dtype=torch.float32)
        scores = torch.tensor([0.9, 0.8], dtype=torch.float32)

        print("Testing NMS functionality...")
        indices = nms(boxes, scores, 0.5)
        print("NMS test successful! Result:", indices)
    except Exception as e:
        print(f"Error importing or using NMS: {e}")

    # Check for dependencies that might use NMS
    print("\nChecking dependencies that might use NMS operator...")
    deps_to_check = [
        'trl', 'transformers', 'peft', 'accelerate',
        'bitsandbytes', 'datasets'
    ]

    for dep in deps_to_check:
        try:
            module = importlib.import_module(dep)
            version = getattr(module, '__version__', 'unknown')
            print(f"✓ {dep} version: {version}")
        except ImportError:
            print(f"✗ {dep} not installed")

if __name__ == "__main__":
    check_installations()
config.yaml
ADDED
@@ -0,0 +1,46 @@

# Heaven Model Configuration for Llama 3.2 Fine-tuning

# Dataset configuration
dataset:
  size: 10000                     # Number of examples to generate
  predatory_ratio: 0.5            # Ratio of predatory examples (0-1)
  output_path: "data/heaven_dataset.jsonl"

# Model configuration
model:
  name_or_path: "meta-llama/Llama-3.2-3B-Instruct"  # HuggingFace model identifier
  output_dir: "./heaven1-base-8b"                   # Directory to save fine-tuned model

# Training configuration
training:
  num_epochs: 3                   # Number of training epochs
  batch_size: 1                   # Batch size per device
  gradient_accumulation_steps: 8  # Number of steps to accumulate gradients
  learning_rate: 2e-5             # Initial learning rate
  weight_decay: 0.01              # Weight decay coefficient
  max_grad_norm: 1.0              # Max gradient norm for clipping
  warmup_ratio: 0.1               # Linear warmup ratio
  eval_ratio: 0.1                 # Portion of data used for evaluation
  max_seq_length: 4096            # Maximum sequence length

# PEFT configuration (Parameter-Efficient Fine-Tuning)
peft:
  use_lora: true                  # Whether to use LoRA
  use_qlora: true                 # Whether to use QLoRA (quantized LoRA)
  lora_r: 16                      # LoRA rank
  lora_alpha: 32                  # LoRA scaling factor
  lora_dropout: 0.05              # LoRA dropout rate

# Precision configuration
precision:
  fp16: false                     # Whether to use fp16 mixed precision
  bf16: true                      # Whether to use bf16 mixed precision
  compute_dtype: "float16"        # Compute dtype for quantization

# Logging configuration
logging:
  use_wandb: false                # Whether to use Weights & Biases
  run_name: "heaven-llama3-2"     # Name of the run
  logging_steps: 10               # Steps between logging
  eval_steps: 100                 # Steps between evaluation
  save_steps: 100                 # Steps between saving checkpoints
create_dataset.py
ADDED
@@ -0,0 +1,251 @@

import json
import os
import random
from typing import List, Dict, Any
import argparse
import jsonlines
from tqdm import tqdm

# Define common patterns for predatory behavior
PREDATORY_PATTERNS = {
    "grooming": [
        "You seem so mature for your age",
        "I feel like we have a special connection",
        "You're not like other kids/people your age",
        "I understand you better than your parents/friends",
        "You can trust me with anything",
    ],
    "isolation": [
        "Don't tell your parents about our conversations",
        "This should be just between us",
        "Your friends wouldn't understand our relationship",
        "You don't need to tell anyone about this",
        "Let's keep this our secret",
    ],
    "personal_info": [
        "Where do you live exactly?",
        "Are you home alone right now?",
        "What's your address?",
        "When are your parents usually out?",
        "What school do you go to?",
    ],
    "image_requests": [
        "Send me a picture of yourself",
        "I'd love to see what you look like",
        "Can you send a photo?",
        "Do you have any pictures you could share?",
        "Send me a pic",
    ],
    "private_platform": [
        "Let's talk on Snapchat instead",
        "Do you have WhatsApp? We should chat there",
        "Give me your Instagram, we can DM",
        "Let's move this conversation somewhere more private",
        "What's your Discord? It's better to chat there",
    ],
    "emotional_manipulation": [
        "If you really cared about me, you would...",
        "I thought you were different, but I guess I was wrong",
        "After everything I've done for you...",
        "No one has ever made me feel this way before",
        "You're making me sad by not trusting me",
    ],
    "sexual_content": [
        "Have you ever kissed anyone?",
        "Are you comfortable talking about intimate things?",
        "Do you know what [explicit term] means?",
        "Have you ever seen [explicit content]?",
        "I could teach you about [suggestive content]",
    ]
}

# Define normal conversation patterns
NORMAL_PATTERNS = {
    "greetings": [
        "Hey, how are you doing today?",
        "Hi there! How's your day going?",
        "Hello! How have you been?",
        "Good morning! Ready for the day?",
        "Hey, just checking in to see how you're doing",
    ],
    "casual_plans": [
        "Want to grab coffee sometime?",
        "Are you going to the event this weekend?",
        "We should hang out at the mall with the group",
        "Do you want to see that new movie?",
        "Let's meet up at the park with everyone",
    ],
    "interests": [
        "What kind of music do you listen to?",
        "Have you watched any good shows lately?",
        "What are your hobbies?",
        "Did you see the game last night?",
        "What books are you reading these days?",
    ],
    "school_work": [
        "How are your classes going?",
        "Did you finish that assignment?",
        "Do you understand the math homework?",
        "I'm struggling with this project, any advice?",
        "Are you ready for the test tomorrow?",
    ],
    "support": [
        "I'm here if you need to talk",
        "Hope things get better soon",
        "Let me know if you need anything",
        "That sounds tough, how are you handling it?",
        "I believe in you, you can do this",
    ]
}

def generate_predatory_message() -> str:
    """Generate a synthetic predatory message with multiple red flags."""
    message_parts = []

    # Select 2-4 pattern categories at random
    categories = random.sample(list(PREDATORY_PATTERNS.keys()), random.randint(2, 4))

    # Add a greeting sometimes
    if random.random() < 0.7:
        message_parts.append(random.choice(NORMAL_PATTERNS["greetings"]))

    # Add predatory patterns
    for category in categories:
        message_parts.append(random.choice(PREDATORY_PATTERNS[category]))

    # Sometimes mix in normal conversation to make it less obvious
    if random.random() < 0.5:
        normal_category = random.choice(list(NORMAL_PATTERNS.keys()))
        message_parts.append(random.choice(NORMAL_PATTERNS[normal_category]))

    # Shuffle the parts to create a more natural conversation
    if len(message_parts) > 2:  # Keep greeting first if it exists
        first_part = message_parts[0] if random.random() < 0.7 else ""
        remaining_parts = message_parts[1:] if first_part else message_parts
        random.shuffle(remaining_parts)
        if first_part:
            message_parts = [first_part] + remaining_parts
        else:
            message_parts = remaining_parts

    return " ".join(message_parts)

def generate_normal_message() -> str:
    """Generate a synthetic normal message."""
    message_parts = []

    # Select 2-3 pattern categories at random
    categories = random.sample(list(NORMAL_PATTERNS.keys()), random.randint(2, 3))

    for category in categories:
        message_parts.append(random.choice(NORMAL_PATTERNS[category]))

    return " ".join(message_parts)

def generate_predatory_explanation(message: str) -> str:
    """Generate an explanation for why a message is predatory."""
    explanation = "PREDATORY BEHAVIOR DETECTED. This message contains multiple warning signs: "
    warning_signs = []

    for category, patterns in PREDATORY_PATTERNS.items():
        for pattern in patterns:
            if pattern.lower() in message.lower():
                if category == "grooming":
                    warning_signs.append(f"Grooming language ('{pattern}')")
                elif category == "isolation":
                    warning_signs.append(f"Isolation attempt ('{pattern}')")
                elif category == "personal_info":
                    warning_signs.append(f"Seeking personal information ('{pattern}')")
                elif category == "image_requests":
                    warning_signs.append(f"Request for images ('{pattern}')")
                elif category == "private_platform":
                    warning_signs.append(f"Attempting to move to private communication ('{pattern}')")
                elif category == "emotional_manipulation":
                    warning_signs.append(f"Emotional manipulation ('{pattern}')")
                elif category == "sexual_content":
                    warning_signs.append(f"Inappropriate sexual content ('{pattern}')")

    for i, sign in enumerate(warning_signs):
        if i == 0:
            explanation += f"(1) {sign}"
        else:
            explanation += f", ({i+1}) {sign}"

    explanation += ". These are common tactics used by predators to manipulate potential victims."

    return explanation

def generate_normal_explanation(message: str) -> str:
    """Generate an explanation for why a message is not predatory."""
    explanation = "NO PREDATORY BEHAVIOR DETECTED. This message contains normal friendly communication. "

    if any(pattern.lower() in message.lower() for pattern in NORMAL_PATTERNS["greetings"]):
        explanation += "It includes a casual greeting. "
    if any(pattern.lower() in message.lower() for pattern in NORMAL_PATTERNS["casual_plans"]):
        explanation += "It contains appropriate social plans. "
    if any(pattern.lower() in message.lower() for pattern in NORMAL_PATTERNS["interests"]):
        explanation += "It shows interest in common topics. "
    if any(pattern.lower() in message.lower() for pattern in NORMAL_PATTERNS["school_work"]):
        explanation += "It discusses school or work. "
    if any(pattern.lower() in message.lower() for pattern in NORMAL_PATTERNS["support"]):
        explanation += "It offers appropriate support. "

    explanation += "There are no attempts at manipulation, isolation, inappropriate requests, or other warning signs of predatory behavior."

    return explanation

def create_dataset_entry(predatory: bool = False) -> Dict[str, List[Dict[str, str]]]:
    """Create a single dataset entry in the format required for Llama 3.2 fine-tuning."""
    system_message = "You are Heaven, an AI designed to detect predatory behavior in text messages. Analyze the following message and determine if it contains predatory behavior. Provide a detailed explanation for your assessment."

    if predatory:
        user_message = generate_predatory_message()
        assistant_message = generate_predatory_explanation(user_message)
    else:
        user_message = generate_normal_message()
        assistant_message = generate_normal_explanation(user_message)

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message}
        ]
    }

def create_dataset(size: int, predatory_ratio: float = 0.5, output_path: str = "heaven_dataset.jsonl"):
    """Create a full dataset with the specified size and ratio of predatory examples."""
    predatory_count = int(size * predatory_ratio)
    normal_count = size - predatory_count

    dataset = []

    print(f"Generating {predatory_count} predatory examples...")
    for _ in tqdm(range(predatory_count)):
        dataset.append(create_dataset_entry(predatory=True))

    print(f"Generating {normal_count} normal examples...")
    for _ in tqdm(range(normal_count)):
        dataset.append(create_dataset_entry(predatory=False))

    # Shuffle the dataset
    random.shuffle(dataset)

    # Save to JSONL file
    with jsonlines.open(output_path, mode='w') as writer:
        for entry in dataset:
            writer.write(entry)

    print(f"Dataset saved to {output_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a synthetic dataset for predatory behavior detection")
    parser.add_argument("--size", type=int, default=1000, help="Total size of the dataset")
    parser.add_argument("--ratio", type=float, default=0.5, help="Ratio of predatory examples (0-1)")
    parser.add_argument("--output", type=str, default="data/heaven_dataset.jsonl", help="Output file path")
    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(os.path.dirname(args.output), exist_ok=True)

    create_dataset(args.size, args.ratio, args.output)
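Since every record written above is a three-turn chat (system, user, assistant), a quick post-generation check can catch format drift before fine-tuning. A minimal sketch, assuming the default output path; this snippet is not part of the commit:

```python
import jsonlines

# Peek at the first examples and verify the chat structure
with jsonlines.open("data/heaven_dataset.jsonl") as reader:
    for i, entry in enumerate(reader):
        roles = [m["role"] for m in entry["messages"]]
        assert roles == ["system", "user", "assistant"], f"bad entry {i}: {roles}"
        if i < 2:
            print(entry["messages"][1]["content"][:80])  # sample user message
        if i >= 99:
            break
```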
data/heaven_dataset.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
finetune_heaven.py
ADDED
@@ -0,0 +1,331 @@

import os
import argparse
import torch
import numpy as np
import time
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
import wandb

def format_prompt(examples):
    """
    Format the prompts for Llama 3.2 instruction fine-tuning.
    This function processes batches of examples and returns formatted strings.
    """
    # Process each example in the batch
    formatted_prompts = []

    # Handle batch processing - if it's a single example, wrap it in a list
    messages_list = examples["messages"] if isinstance(examples, dict) and "messages" in examples else [examples]

    # Process each message in the batch
    for messages in messages_list:
        if isinstance(messages, list) and len(messages) >= 3:
            system = messages[0]["content"] if isinstance(messages[0], dict) and "content" in messages[0] else ""
            user_message = messages[1]["content"] if isinstance(messages[1], dict) and "content" in messages[1] else ""
            assistant_message = messages[2]["content"] if isinstance(messages[2], dict) and "content" in messages[2] else ""
        else:
            # Fallback for unexpected structure
            print(f"Warning: Unexpected message structure: {messages}")
            system = ""
            user_message = ""
            assistant_message = ""

        # Format the prompt
        formatted = f"<|system|>\n{system}\n<|user|>\n{user_message}\n<|assistant|>\n{assistant_message}"
        formatted_prompts.append(formatted)

    return formatted_prompts

def preprocess_function(examples, tokenizer, max_length=4096):
    """
    Tokenize the examples for training.

    Note: this module-level helper is superseded by the local
    preprocess_function defined inside train(), which is what the
    dataset .map() call actually uses.
    """
    # Get formatted prompts
    formatted_prompts = [format_prompt(example) for example in examples["messages"]]

    # Tokenize
    tokenized_inputs = tokenizer(
        formatted_prompts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    # Create labels (same as input_ids since we're doing causal LM training)
    tokenized_inputs["labels"] = tokenized_inputs["input_ids"].clone()

    return tokenized_inputs

def train(args):
    print("Initializing training process...")

    # Initialize wandb if tracking is enabled
    if args.use_wandb:
        print("Initializing Weights & Biases...")
        wandb.init(project="heaven-llama3-2", name=args.run_name)

    # Load the tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")

    # Configure quantization if using QLoRA
    print("Setting up model configuration...")
    if args.use_qlora:
        print("Configuring QLoRA quantization...")
        compute_dtype = getattr(torch, args.compute_dtype)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        quantization_config = None

    # Load the model with use_cache=False for compatibility with gradient checkpointing
    print(f"Loading model: {args.model_name_or_path}")
    print(f"GPU available: {torch.cuda.is_available()}, Device count: {torch.cuda.device_count()}")
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}, "
                  f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")

    start_time = time.time()
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
        use_cache=False  # Set use_cache=False explicitly for gradient checkpointing
    )
    load_time = time.time() - start_time
    print(f"Model loaded in {load_time:.2f} seconds")

    # Prepare model for k-bit training if using QLoRA
    if args.use_qlora:
        print("Preparing model for k-bit training...")
        model = prepare_model_for_kbit_training(model)

    # Set up LoRA configuration
    peft_config = None
    if args.use_lora:
        print("Setting up LoRA configuration...")
        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # Load and prepare dataset
    print(f"Loading dataset from {args.dataset_path}...")
    start_time = time.time()
    dataset = load_dataset("json", data_files=args.dataset_path)
    print(f"Dataset loaded in {time.time() - start_time:.2f} seconds. Size: {len(dataset['train'])} examples")

    # Process the dataset manually to ensure correct formatting
    print("Preprocessing dataset...")
    start_time = time.time()

    def preprocess_function(examples):
        # This local function shadows the module-level helper and is the
        # version actually used by .map() below.
        if len(examples["messages"]) % 100 == 0:
            print(f"Processing batch of {len(examples['messages'])} examples...")
        formatted_texts = format_prompt(examples)
        return tokenizer(
            formatted_texts,
            padding="max_length",
            truncation=True,
            max_length=args.max_seq_length,
            return_tensors=None  # Return Python lists
        )

    processed_dataset = dataset["train"].map(
        preprocess_function,
        batched=True,
        batch_size=100,
        remove_columns=["messages"],
        desc="Processing dataset"
    )
    print(f"Dataset preprocessing completed in {time.time() - start_time:.2f} seconds")

    # Split dataset into train and evaluation
    print(f"Splitting dataset with test_size={args.eval_ratio}...")
    split_dataset = processed_dataset.train_test_split(test_size=args.eval_ratio)
    print(f"Train set: {len(split_dataset['train'])} examples, Test set: {len(split_dataset['test'])} examples")

    # Create a data collator for language modeling.
    # Note: this collator is defined but never passed to SFTTrainer below,
    # which therefore falls back to its own default collator.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    # Set up training arguments
    print("Configuring training arguments...")
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        eval_strategy="steps",  # Use newer parameter name
        save_strategy="steps",
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        logging_steps=args.logging_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        fp16=args.fp16,
        bf16=args.bf16,
        max_grad_norm=args.max_grad_norm,
        max_steps=args.max_steps,
        warmup_ratio=args.warmup_ratio,
        group_by_length=False,
        lr_scheduler_type=args.lr_scheduler_type,
        report_to="wandb" if args.use_wandb else "none",
        save_total_limit=3,
        remove_unused_columns=False,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        # Add gradient checkpointing settings
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},  # Explicitly set use_reentrant=False
    )

    # Create the SFT trainer
    print("Initializing SFTTrainer...")
    trainer = SFTTrainer(
        model=model,
        train_dataset=split_dataset["train"],
        eval_dataset=split_dataset["test"],
        args=training_args,
        tokenizer=tokenizer,
        # No formatting_func since we pre-process the dataset above
        max_seq_length=args.max_seq_length,
        # Pass peft_config separately if using LoRA (the model was already
        # wrapped via get_peft_model above; depending on the TRL version
        # this second pass may be redundant)
        peft_config=peft_config if args.use_lora else None,
    )

    # Train the model
    print("Starting training...")
    print("If training appears stuck here, the trainer might be compiling the model or allocating memory.")
    print("For large models, this can take several minutes, especially on the first training step.")

    try:
        train_result = trainer.train()

        # Save the model
        print(f"Saving model to {args.output_dir}")
        trainer.save_model(args.output_dir)

        # Save training metrics
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

        # Evaluate model
        print("Evaluating model...")
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

        print(f"Training complete! Model saved to {args.output_dir}")
    except Exception as e:
        print(f"Error during training: {e}")
        import traceback
        traceback.print_exc()

    # Close wandb if used
    if args.use_wandb:
        wandb.finish()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune Llama 3.2 for predatory behavior detection")

    # Model and dataset arguments
    parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Meta-Llama-3.2-8B-Instruct",
                        help="Path to pretrained model or model identifier from huggingface.co/models")
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="Path to the JSONL dataset for fine-tuning")
    parser.add_argument("--output_dir", type=str, default="./heaven-model",
                        help="Directory to save the fine-tuned model")

    # Training hyperparameters
    parser.add_argument("--num_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Batch size per device")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
                        help="Number of update steps to accumulate before performing a backward/update pass")
    parser.add_argument("--learning_rate", type=float, default=2e-5,
                        help="Initial learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay to apply")
    parser.add_argument("--max_grad_norm", type=float, default=1.0,
                        help="Max gradient norm for gradient clipping")
    parser.add_argument("--max_steps", type=int, default=-1,
                        help="If > 0, set total number of training steps to perform. Overrides num_epochs")
    parser.add_argument("--warmup_ratio", type=float, default=0.1,
                        help="Linear warmup over warmup_ratio fraction of total steps")
    parser.add_argument("--eval_ratio", type=float, default=0.1,
                        help="Ratio of data to use for evaluation")
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine",
                        help="Learning rate scheduler type")
    parser.add_argument("--max_seq_length", type=int, default=4096,
                        help="Maximum sequence length for training")

    # Logging and evaluation arguments
    parser.add_argument("--logging_steps", type=int, default=10,
                        help="Number of steps between logging")
    parser.add_argument("--eval_steps", type=int, default=100,
                        help="Number of steps between evaluations")
    parser.add_argument("--save_steps", type=int, default=100,
                        help="Number of steps between saving model checkpoints")
    parser.add_argument("--run_name", type=str, default="heaven-llama3-2",
                        help="Name of the run for logging")
    parser.add_argument("--use_wandb", action="store_true",
                        help="Whether to use Weights & Biases for logging")

    # PEFT arguments
    parser.add_argument("--use_lora", action="store_true",
                        help="Whether to use LoRA for fine-tuning")
    parser.add_argument("--use_qlora", action="store_true",
                        help="Whether to use QLoRA for fine-tuning (4-bit quantization with LoRA)")
    parser.add_argument("--lora_r", type=int, default=16,
                        help="Rank of the LoRA update matrices")
    parser.add_argument("--lora_alpha", type=int, default=32,
                        help="Scaling factor for LoRA")
    parser.add_argument("--lora_dropout", type=float, default=0.05,
                        help="Dropout probability for LoRA")

    # Mixed precision arguments
    parser.add_argument("--fp16", action="store_true",
                        help="Whether to use fp16 mixed precision")
    parser.add_argument("--bf16", action="store_true",
                        help="Whether to use bf16 mixed precision")
    parser.add_argument("--compute_dtype", type=str, default="float16",
                        help="Compute dtype for 4-bit quantization")

    args = parser.parse_args()

    train(args)
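If a run saves only a LoRA adapter rather than merged weights (the usual outcome with `--use_lora`), the adapter can be attached to the base model at inference time with PEFT. A sketch under that assumption; the paths below are illustrative:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_id = "meta-llama/Llama-3.2-3B-Instruct"  # base model from config.yaml
adapter_dir = "./heaven1-base-8b"             # output_dir used during training

tokenizer = AutoTokenizer.from_pretrained(base_id)
base = AutoModelForCausalLM.from_pretrained(
    base_id, torch_dtype=torch.float16, device_map="auto"
)

# Attach the fine-tuned LoRA weights on top of the frozen base model
model = PeftModel.from_pretrained(base, adapter_dir)
model = model.merge_and_unload()  # optional: fold the adapter into the base weights
```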
fix_nms_error.py
ADDED
@@ -0,0 +1,144 @@

import os
import sys
import importlib.util
from pathlib import Path

def locate_trl_module():
    """Find the location of the TRL module in the Python path."""
    try:
        spec = importlib.util.find_spec('trl')
        if spec is None:
            print("TRL module not found in the Python path")
            return None

        trl_path = Path(spec.origin).parent
        print(f"Found TRL module at: {trl_path}")
        return trl_path
    except Exception as e:
        print(f"Error locating TRL module: {e}")
        return None

def patch_sft_trainer():
    """Patch the SFTTrainer to avoid using torchvision's NMS operator."""
    trl_path = locate_trl_module()
    if trl_path is None:
        return False

    # Path to the trainer.py file which likely contains the NMS reference
    trainer_path = trl_path / "trainer" / "sft_trainer.py"

    if not trainer_path.exists():
        print(f"Could not find the SFT trainer file at: {trainer_path}")
        return False

    print(f"Found SFT trainer file at: {trainer_path}")

    # Read the file content
    with open(trainer_path, "r") as f:
        content = f.read()

    # Check if 'torchvision' is in the file
    if "torchvision" not in content:
        print("No torchvision imports found in the SFT trainer file.")
        return False

    # Create backup
    backup_path = trainer_path.with_suffix(".py.bak")
    print(f"Creating backup at: {backup_path}")
    with open(backup_path, "w") as f:
        f.write(content)

    # Replace imports - common patterns
    patched_content = content

    # Pattern 1: Direct import of nms
    patched_content = patched_content.replace(
        "from torchvision.ops import nms",
        "# from torchvision.ops import nms  # Commented out to fix NMS error"
    )

    # Pattern 2: Import torchvision
    patched_content = patched_content.replace(
        "import torchvision",
        "# import torchvision  # Commented out to fix NMS error"
    )

    # Pattern 3: Import from torchvision.ops (this also re-touches lines
    # already rewritten by Pattern 1, which remain valid comments)
    patched_content = patched_content.replace(
        "from torchvision.ops",
        "# from torchvision.ops  # Commented out to fix NMS error"
    )

    # Add our custom NMS implementation.
    # NOTE: the embedded docstring uses single-quoted triple quotes so it
    # does not terminate this double-quoted source template prematurely.
    custom_nms = """
# Custom NMS implementation to avoid torchvision dependency
def nms(boxes, scores, iou_threshold):
    '''
    Performs non-maximum suppression (NMS) on the boxes according to their
    intersection-over-union (IoU).

    Args:
        boxes (Tensor[N, 4]): boxes to perform NMS on
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
    '''
    import torch

    # Sort boxes by scores
    _, order = scores.sort(0, descending=True)
    keep = []

    while order.numel() > 0:
        if order.numel() == 1:
            keep.append(order.item())
            break

        i = order[0].item()
        keep.append(i)

        # Compute IoU of the remaining boxes with the largest box
        xx1 = torch.max(boxes[i, 0], boxes[order[1:], 0])
        yy1 = torch.max(boxes[i, 1], boxes[order[1:], 1])
        xx2 = torch.min(boxes[i, 2], boxes[order[1:], 2])
        yy2 = torch.min(boxes[i, 3], boxes[order[1:], 3])

        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h

        # IoU = intersection / (area1 + area2 - intersection)
        box_area = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        other_area = (boxes[order[1:], 2] - boxes[order[1:], 0]) * (boxes[order[1:], 3] - boxes[order[1:], 1])
        iou = inter / (box_area + other_area - inter)

        # Keep boxes with IoU less than threshold
        inds = torch.where(iou <= iou_threshold)[0]
        order = order[inds + 1]

    return torch.tensor(keep, dtype=torch.int64)
"""

    # Add our custom implementation somewhere near the imports
    import_end = patched_content.find("\n\n", patched_content.find("import "))
    if import_end == -1:
        import_end = patched_content.find("\n", patched_content.find("import "))

    patched_content = patched_content[:import_end] + custom_nms + patched_content[import_end:]

    # Write the patched file
    with open(trainer_path, "w") as f:
        f.write(patched_content)

    print(f"Successfully patched {trainer_path}")
    print("The SFTTrainer should now work without requiring torchvision's NMS operator")
    return True

if __name__ == "__main__":
    success = patch_sft_trainer()
    if success:
        print("\nPatch applied successfully. You can now run the fine-tuning script.")
    else:
        print("\nFailed to apply the patch. Please check the error messages above.")
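The replacement `nms` is pure PyTorch, so it can be sanity-checked against the dummy boxes from check_torch.py before patching anything. A hypothetical standalone check, assuming the function body above has been pasted into the same session:

```python
import torch

# Same dummy data as check_torch.py: the two boxes overlap with IoU ~0.14,
# well under the 0.5 threshold, so both should survive suppression.
boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]], dtype=torch.float32)
scores = torch.tensor([0.9, 0.8], dtype=torch.float32)

kept = nms(boxes, scores, 0.5)  # the pure-PyTorch nms defined in the patch
print(kept)  # expected: tensor([0, 1])
```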
model_card.md
ADDED
@@ -0,0 +1,156 @@

---
language: en
license: mit
tags:
- llama
- llama-3.2
- safeguarding
- content-moderation
- safety
- predator-detection
- text-classification
datasets:
- safecircleia/heaven1-dataset
metrics:
- accuracy
- precision
- recall
- f1
pipeline_tag: text-classification
widget:
- text: "Hey, I know we just met but I feel like we have a special connection. Don't tell your parents about our chats, they wouldn't understand. Can you send me a picture of yourself?"
- text: "Hey, just checking in to see how your day went. Let me know if you want to grab coffee this weekend."
---

# Heaven1-base: Guardian

<img src="https://huggingface.co/safecircleia/heaven1-base/resolve/main/Heaven1-guardian.png" alt="Heaven1-base Guardian Banner" width="600">

Heaven1-base (codename: "Guardian") is a specialized AI model fine-tuned from Llama 3.2 to detect predatory behavior in text messages. The model analyzes conversation patterns to identify potential warning signs of grooming, manipulation, or predatory tactics.

## Model Details

- **Developed by:** SafeCircleIA
- **Model type:** Fine-tuned Llama 3.2
- **Language(s):** English
- **Base model:** meta-llama/Llama-3.2-8B-Instruct
- **Training approach:** Parameter-Efficient Fine-Tuning (PEFT) using QLoRA
- **License:** MIT

## Intended Use

This model is intended to serve as a protective tool for:

- Content moderation teams
- Platform safety engineers
- Organizations focused on child and vulnerable adult safety online
- Researchers studying digital safety and online predatory behavior

### Primary intended uses

- Detecting potentially harmful interactions in text messages
- Providing explanations for why certain messages contain predatory elements
- Assisting human moderators in identifying concerning patterns
- Supporting research on digital safety

### Primary intended users

- Content moderation teams
- Digital safety professionals
- Platform trust & safety teams
- Child protection services
- Safety-focused researchers

## How to Use

You can use the model via the Hugging Face `transformers` library:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model_path = "safecircleia/heaven1-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# Message to analyze
message_to_analyze = "Hey, I know we just met but I feel like we have a special connection. Don't tell your parents about our chats, they wouldn't understand. Can you send me a picture of yourself?"

# Format the prompt
prompt = f"""<|system|>
You are Heaven, an AI designed to detect predatory behavior in text messages. Analyze the following message and determine if it contains predatory behavior. Provide a detailed explanation for your assessment.
<|user|>
{message_to_analyze}
<|assistant|>
"""

# Generate response
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(inputs["input_ids"], max_new_tokens=512)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(response)
```

## Training Details

- **Training data:** The model was trained on a carefully curated synthetic dataset representing various predatory conversation patterns
- **Training procedure:** Fine-tuned using QLoRA to adapt Llama 3.2's capabilities to predatory text detection
- **Hyperparameters:**
  - LoRA rank: 16
  - LoRA alpha: 32
  - Learning rate: 2e-5
  - Batch size: 1 with gradient accumulation steps of 8
  - Training epochs: 3
  - Maximum sequence length: 4096

## Evaluation Results

Performance metrics on test dataset:

| Metric | Score |
|--------|-------|
| Accuracy | [INSERT VALUE] |
| Precision | [INSERT VALUE] |
| Recall | [INSERT VALUE] |
| F1 | [INSERT VALUE] |

## Limitations & Biases

### Limitations

- The model detects patterns based on its training data and may miss novel predatory tactics
- Performance may vary across different cultural contexts and communication styles
- False positives and false negatives are possible; human review is recommended

### Recommendations

- Do not use as the sole decision-maker for safety-critical applications
- Always combine with human review for best results
- Consider cultural and contextual factors when interpreting results
- Regularly evaluate and update the model as predatory tactics evolve

## Ethical Considerations

This model is designed to help create safer digital environments, particularly for vulnerable individuals like children. However, it should be used responsibly:

- Respect privacy and obtain appropriate consent when analyzing communications
- Be transparent about the use of AI detection systems
- Do not use the model to create or refine predatory language patterns
- Consider the impact of false positives on legitimate communications

## Additional Information

For more information, research, or to contribute to the development of digital safety tools, visit [SafeCircleIA website or contact information].

## Citation

```
@misc{heaven1-base-2025,
  author = {SafeCircleIA},
  title = {Heaven1-base: Guardian - Predatory Behavior Detection Model},
  year = {2024},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/safecircleia/heaven1-base-guardian}}
}
```
requirements.txt
ADDED
@@ -0,0 +1,18 @@
transformers>=4.36.0
datasets>=2.14.0
torch>=2.0.0
torchvision>=0.15.0
accelerate>=0.20.0
bitsandbytes>=0.40.0
peft>=0.4.0
trl>=0.7.0
scipy>=1.10.0
numpy>=1.24.0
wandb>=0.15.0
scikit-learn>=1.2.0
tqdm>=4.65.0
jsonlines>=3.1.0
sentencepiece>=0.1.99
protobuf>=3.20.0
einops>=0.6.0
pyyaml>=6.0
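
These dependencies install in one step with `pip install -r requirements.txt`; note that `bitsandbytes`, which provides the 4-bit quantization used by QLoRA, generally requires a CUDA-capable GPU.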
run_heaven.py
ADDED
@@ -0,0 +1,112 @@
import os
import argparse
import yaml
import subprocess

def load_config(config_path):
    """Load configuration from YAML file."""
    with open(config_path, "r") as f:
        return yaml.safe_load(f)

def create_dataset(config):
    """Create dataset using the configuration."""
    dataset_config = config["dataset"]

    cmd = [
        "python", "create_dataset.py",
        "--size", str(dataset_config["size"]),
        "--ratio", str(dataset_config["predatory_ratio"]),
        "--output", dataset_config["output_path"]
    ]

    print(f"Creating dataset with {dataset_config['size']} examples...")
    result = subprocess.run(cmd)
    if result.returncode != 0:
        print("Dataset creation failed!")
        return False

    print(f"Dataset created successfully at {dataset_config['output_path']}")
    return True

def finetune_model(config):
    """Fine-tune the model using the configuration."""
    dataset_config = config["dataset"]
    model_config = config["model"]
    training_config = config["training"]
    peft_config = config["peft"]
    precision_config = config["precision"]
    logging_config = config["logging"]

    cmd = [
        "python", "finetune_heaven.py",
        "--model_name_or_path", model_config["name_or_path"],
        "--dataset_path", dataset_config["output_path"],
        "--output_dir", model_config["output_dir"],
        "--num_epochs", str(training_config["num_epochs"]),
        "--batch_size", str(training_config["batch_size"]),
        "--gradient_accumulation_steps", str(training_config["gradient_accumulation_steps"]),
        "--learning_rate", str(training_config["learning_rate"]),
        "--weight_decay", str(training_config["weight_decay"]),
        "--max_grad_norm", str(training_config["max_grad_norm"]),
        "--warmup_ratio", str(training_config["warmup_ratio"]),
        "--eval_ratio", str(training_config["eval_ratio"]),
        "--max_seq_length", str(training_config["max_seq_length"]),
        "--logging_steps", str(logging_config["logging_steps"]),
        "--eval_steps", str(logging_config["eval_steps"]),
        "--save_steps", str(logging_config["save_steps"]),
        "--run_name", logging_config["run_name"],
        "--compute_dtype", precision_config["compute_dtype"]
    ]

    # Add boolean flags
    if peft_config["use_lora"]:
        cmd.append("--use_lora")
    if peft_config["use_qlora"]:
        cmd.append("--use_qlora")
    if precision_config["fp16"]:
        cmd.append("--fp16")
    if precision_config["bf16"]:
        cmd.append("--bf16")
    if logging_config["use_wandb"]:
        cmd.append("--use_wandb")

    # Add LoRA parameters
    cmd.extend(["--lora_r", str(peft_config["lora_r"])])
    cmd.extend(["--lora_alpha", str(peft_config["lora_alpha"])])
    cmd.extend(["--lora_dropout", str(peft_config["lora_dropout"])])

    print("Starting fine-tuning process...")
    result = subprocess.run(cmd)
    if result.returncode != 0:
        print("Fine-tuning failed!")
        return False

    print(f"Fine-tuning completed successfully! Model saved to {model_config['output_dir']}")
    return True

def main():
    parser = argparse.ArgumentParser(description="Run the Heaven fine-tuning pipeline")
    parser.add_argument("--config", type=str, default="config.yaml", help="Path to the configuration file")
    parser.add_argument("--skip-dataset", action="store_true", help="Skip dataset creation step")
    args = parser.parse_args()

    print(f"Loading configuration from {args.config}...")
    config = load_config(args.config)

    # Create necessary directories (guard against a bare filename with no directory part,
    # for which os.path.dirname returns "" and os.makedirs would raise)
    dataset_dir = os.path.dirname(config["dataset"]["output_path"])
    if dataset_dir:
        os.makedirs(dataset_dir, exist_ok=True)
    os.makedirs(config["model"]["output_dir"], exist_ok=True)

    # Create dataset if not skipped
    if not args.skip_dataset:
        success = create_dataset(config)
        if not success:
            return
    else:
        print("Skipping dataset creation...")

    # Fine-tune the model
    finetune_model(config)

if __name__ == "__main__":
    main()
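For reference, `run_heaven.py` expects the `config.yaml` shipped with this commit to expose the keys it reads above. The actual file is not shown in this section, so the following is only a sketch of the implied structure, with values taken from the README hyperparameters where documented and marked as assumptions everywhere else:

```yaml
# Sketch of the structure run_heaven.py consumes; values marked "assumed"
# are illustrative, not the repository's actual settings.
dataset:
  size: 1000                     # assumed
  predatory_ratio: 0.5           # assumed
  output_path: data/heaven_dataset.jsonl
model:
  name_or_path: meta-llama/Llama-3.2-8B-Instruct   # assumed identifier
  output_dir: ./heaven1-base-8b
training:
  num_epochs: 3
  batch_size: 1
  gradient_accumulation_steps: 8
  learning_rate: 2.0e-5          # written with a dot so PyYAML parses it as a float
  weight_decay: 0.01             # assumed
  max_grad_norm: 1.0             # assumed
  warmup_ratio: 0.03             # assumed
  eval_ratio: 0.1                # assumed
  max_seq_length: 4096
peft:
  use_lora: true
  use_qlora: true
  lora_r: 16
  lora_alpha: 32
  lora_dropout: 0.05             # assumed
precision:
  fp16: false
  bf16: true                     # assumed
  compute_dtype: bfloat16        # assumed
logging:
  logging_steps: 10              # assumed
  eval_steps: 100                # assumed
  save_steps: 100                # assumed
  run_name: heaven1-guardian     # assumed
  use_wandb: false
```

With a config in place, `python run_heaven.py --config config.yaml` runs dataset creation and fine-tuning end to end, and `--skip-dataset` reuses an existing dataset.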
upload_to_hub.py
ADDED
@@ -0,0 +1,94 @@
import os
import argparse
import shutil
from huggingface_hub import create_repo, upload_folder

def upload_model(args):
    """Upload the fine-tuned model to Hugging Face Hub."""
    print(f"Preparing to upload Heaven1-base Guardian model to {args.org}/{args.model_id}")

    # Create a temporary directory for preparing the model
    temp_dir = "temp_upload"
    os.makedirs(temp_dir, exist_ok=True)

    # Copy model files
    print("Copying model files...")
    if os.path.exists(args.model_path):
        for item in os.listdir(args.model_path):
            source = os.path.join(args.model_path, item)
            dest = os.path.join(temp_dir, item)
            if os.path.isdir(source):
                shutil.copytree(source, dest, dirs_exist_ok=True)
            else:
                shutil.copy2(source, dest)
    else:
        print(f"Warning: Model directory {args.model_path} not found.")

    # Copy documentation files
    print("Copying documentation files...")
    docs_files = ["README.md", "model_card.md", "Heaven1-guardian.png"]
    for file in docs_files:
        if os.path.exists(file):
            shutil.copy2(file, os.path.join(temp_dir, file))

    # Rename model_card.md to README.md for proper display on the Hub
    if os.path.exists(os.path.join(temp_dir, "model_card.md")):
        print("Using model_card.md as the main README for the Hub...")
        if os.path.exists(os.path.join(temp_dir, "README.md")):
            # If both exist, rename the original README to avoid overwriting
            os.rename(os.path.join(temp_dir, "README.md"), os.path.join(temp_dir, "DETAILED_README.md"))
        os.rename(os.path.join(temp_dir, "model_card.md"), os.path.join(temp_dir, "README.md"))

    # Create repository if it doesn't exist
    try:
        print(f"Creating repository: {args.org}/{args.model_id}")
        create_repo(
            repo_id=f"{args.org}/{args.model_id}",
            token=args.token,
            private=args.private,
            repo_type="model",
            exist_ok=True,
        )
    except Exception as e:
        print(f"Repository creation error (it might already exist): {e}")

    # Upload model to Hugging Face Hub
    print(f"Uploading files to {args.org}/{args.model_id}...")
    response = upload_folder(
        folder_path=temp_dir,
        repo_id=f"{args.org}/{args.model_id}",
        token=args.token,
        repo_type="model",
        ignore_patterns=[".*", "__pycache__/*", "temp_upload/*"],
    )

    print(f"Upload complete! Model available at: https://huggingface.co/{args.org}/{args.model_id}")

    # Clean up
    if not args.keep_temp:
        print("Cleaning up temporary directory...")
        shutil.rmtree(temp_dir)

    return response

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Upload Heaven1-base Guardian model to Hugging Face Hub")
    parser.add_argument("--model_path", type=str, default="./heaven1-base-8b",
                        help="Path to the fine-tuned model directory")
    parser.add_argument("--org", type=str, default="safecircleia",
                        help="Organization name on Hugging Face Hub")
    parser.add_argument("--model_id", type=str, default="heaven1-base-guardian",
                        help="Model ID for the repository")
    parser.add_argument("--token", type=str, required=True,
                        help="Hugging Face authentication token")
    parser.add_argument("--private", action="store_true",
                        help="Whether to make the repository private")
    parser.add_argument("--keep_temp", action="store_true",
                        help="Keep temporary upload directory after completion")

    args = parser.parse_args()

    upload_model(args)
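A typical invocation, with a placeholder token, is `python upload_to_hub.py --token hf_your_token_here --private`: the script stages the model directory and documentation in `temp_upload/`, creates the `safecircleia/heaven1-base-guardian` repository if it does not already exist, and uploads the staged folder.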