Tomas committed on
Commit 58af2e6 · unverified · 1 Parent(s): 9346c5e

Add initial project setup with model configuration, requirements, and upload script

Heaven1-guardian.png ADDED
README.md CHANGED
@@ -1,3 +1,109 @@
- ---
- license: mit
- ---
+ # Heaven1-base: Guardian
+
+ ![Heaven1-base Guardian Banner](Heaven1-guardian.png)
+
+ ## Overview
+
+ Heaven1-base (codename: "Guardian") is a specialized AI model fine-tuned from Llama 3.2 to detect predatory behavior in text messages. Designed as a protective tool, Guardian analyzes conversations to identify potentially harmful interactions, making digital spaces safer for vulnerable individuals.
+
+ The model has been trained to recognize tactics commonly employed by predators, including:
+
+ - Grooming language and manipulation
+ - Attempts to isolate victims from support networks
+ - Requests for personal information or images
+ - Attempts to move conversations to more private platforms
+ - Emotional manipulation tactics
+ - Inappropriate sexual content
+
+ ## Technical Details
+
+ - **Base Model**: meta-llama/Llama-3.2-3B-Instruct
+ - **Training Method**: Parameter-Efficient Fine-Tuning (PEFT) using QLoRA
+ - **Training Dataset**: Carefully crafted synthetic dataset representing various predatory conversation patterns
+ - **Task**: Text message analysis and predatory behavior detection with detailed explanations
+
+ ## Usage
+
+ ### Input Format
+
+ The model expects input in the following format:
+
+ ```
+ <|system|>
+ You are Heaven, an AI designed to detect predatory behavior in text messages. Analyze the following message and determine if it contains predatory behavior. Provide a detailed explanation for your assessment.
+ <|user|>
+ [TEXT MESSAGE TO ANALYZE]
+ <|assistant|>
+ ```
+
+ ### Output Format
+
+ The model responds with a detection result and a detailed explanation:
+
+ ```
+ PREDATORY BEHAVIOR DETECTED. This message contains multiple warning signs: (1) [Warning Sign 1], (2) [Warning Sign 2], etc. These are common tactics used by predators to manipulate potential victims.
+
+ OR
+
+ NO PREDATORY BEHAVIOR DETECTED. This message contains normal friendly communication. [Additional context about the message]. There are no attempts at manipulation, isolation, inappropriate requests, or other warning signs of predatory behavior.
+ ```
+
+ ### Example Usage with Transformers
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Load model and tokenizer
+ model_path = "safecircleia/heaven1-base"
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = AutoModelForCausalLM.from_pretrained(model_path)
+
+ # Message to analyze
+ message_to_analyze = "Hey, I know we just met but I feel like we have a special connection. Don't tell your parents about our chats, they wouldn't understand. Can you send me a picture of yourself?"
+
+ # Format the prompt
+ prompt = f"""<|system|>
+ You are Heaven, an AI designed to detect predatory behavior in text messages. Analyze the following message and determine if it contains predatory behavior. Provide a detailed explanation for your assessment.
+ <|user|>
+ {message_to_analyze}
+ <|assistant|>
+ """
+
+ # Generate response
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+ outputs = model.generate(inputs["input_ids"], max_new_tokens=512)
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ print(response)
+ ```
+
+ ## Ethical Considerations
+
+ - This model is designed as a protective tool to help identify potentially harmful communication patterns.
+ - False positives and false negatives are possible; human review should be employed for critical applications.
+ - The model should be used as part of a broader safety framework, not as the sole decision-maker.
+ - Privacy and consent are essential when analyzing communications.
+
+ ## Limitations
+
+ - The model detects patterns based on its training data and may miss novel predatory tactics.
+ - Cultural and contextual nuances may impact detection accuracy.
+ - The model is not a substitute for human judgment in safeguarding vulnerable individuals.
+
+ ## Citation
+
+ If you use Heaven1-base Guardian in your research or applications, please cite:
+
+ ```
+ @misc{heaven1-base-2025,
+   author = {SafeCircleIA},
+   title = {Heaven1-base: Guardian - Predatory Behavior Detection Model},
+   year = {2025},
+   publisher = {Hugging Face},
+   howpublished = {\url{https://huggingface.co/safecircleia/heaven1-base-guardian}}
+ }
+ ```
+
+ ## Contact
+
+ For questions, feedback, or concerns about the Heaven1-base Guardian model, please contact SafeCircleIA through Hugging Face or via contact@safecircle.tech.
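
The two verdict prefixes documented in the Output Format section are what downstream code should key on. As a minimal sketch (not part of this commit; the `extract_verdict` helper and the assumption that the prefixes are stable are illustrative), a caller might map a generated response to a boolean like this:

```python
def extract_verdict(response: str) -> bool:
    """Map a Guardian response to a boolean verdict.

    Assumes the response begins with one of the two prefixes shown in
    the README's Output Format section.
    """
    # generate() echoes the prompt, so look only at the text after the
    # final assistant tag when it is present.
    tail = response.split("<|assistant|>")[-1].strip()
    if tail.startswith("PREDATORY BEHAVIOR DETECTED"):
        return True
    if tail.startswith("NO PREDATORY BEHAVIOR DETECTED"):
        return False
    raise ValueError(f"Unrecognized verdict: {tail[:60]!r}")
```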
check_torch.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ import torchvision
+ import sys
+ import importlib
+
+ def check_installations():
+     """Check PyTorch and torchvision installations."""
+     print("Python version:", sys.version)
+     print("PyTorch version:", torch.__version__)
+     print("torchvision version:", torchvision.__version__)
+     print("CUDA available:", torch.cuda.is_available())
+     if torch.cuda.is_available():
+         print("CUDA version:", torch.version.cuda)
+         print("cuDNN version:", torch.backends.cudnn.version())
+
+     # Check if PyTorch and torchvision versions are compatible.
+     # Heuristic for the torch 2.x series: torch 2.N pairs with
+     # torchvision 0.(N+15), e.g. torch 2.0 <-> torchvision 0.15.
+     torch_major, torch_minor = (int(x) for x in torch.__version__.split('.')[:2])
+     tv_major, tv_minor = (int(x) for x in torchvision.__version__.split('.')[:2])
+
+     if torch_major == 2 and (tv_major != 0 or tv_minor != torch_minor + 15):
+         print("WARNING: PyTorch and torchvision versions might be incompatible!")
+         print("It's recommended to install the torchvision build that matches your PyTorch release.")
+
+     # Check for NMS operator
+     try:
+         print("\nAttempting to import torchvision.ops...")
+         from torchvision.ops import nms
+         print("Successfully imported NMS operator.")
+
+         # Create dummy data to test NMS
+         boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]], dtype=torch.float32)
+         scores = torch.tensor([0.9, 0.8], dtype=torch.float32)
+
+         print("Testing NMS functionality...")
+         indices = nms(boxes, scores, 0.5)
+         print("NMS test successful! Result:", indices)
+     except Exception as e:
+         print(f"Error importing or using NMS: {e}")
+
+     # Check for dependencies that might use NMS
+     print("\nChecking dependencies that might use the NMS operator...")
+     deps_to_check = [
+         'trl', 'transformers', 'peft', 'accelerate',
+         'bitsandbytes', 'datasets'
+     ]
+
+     for dep in deps_to_check:
+         try:
+             module = importlib.import_module(dep)
+             version = getattr(module, '__version__', 'unknown')
+             print(f"✓ {dep} version: {version}")
+         except ImportError:
+             print(f"✗ {dep} not installed")
+
+ if __name__ == "__main__":
+     check_installations()
config.yaml ADDED
@@ -0,0 +1,46 @@
+ # Heaven Model Configuration for Llama 3.2 Fine-tuning
+
+ # Dataset configuration
+ dataset:
+   size: 10000                     # Number of examples to generate
+   predatory_ratio: 0.5            # Ratio of predatory examples (0-1)
+   output_path: "data/heaven_dataset.jsonl"
+
+ # Model configuration
+ model:
+   name_or_path: "meta-llama/Llama-3.2-3B-Instruct"  # HuggingFace model identifier
+   output_dir: "./heaven1-base-8b"                   # Directory to save fine-tuned model
+
+ # Training configuration
+ training:
+   num_epochs: 3                   # Number of training epochs
+   batch_size: 1                   # Batch size per device
+   gradient_accumulation_steps: 8  # Number of steps to accumulate gradients
+   learning_rate: 2e-5             # Initial learning rate
+   weight_decay: 0.01              # Weight decay coefficient
+   max_grad_norm: 1.0              # Max gradient norm for clipping
+   warmup_ratio: 0.1               # Linear warmup ratio
+   eval_ratio: 0.1                 # Portion of data used for evaluation
+   max_seq_length: 4096            # Maximum sequence length
+
+ # PEFT configuration (Parameter-Efficient Fine-Tuning)
+ peft:
+   use_lora: true                  # Whether to use LoRA
+   use_qlora: true                 # Whether to use QLoRA (quantized LoRA)
+   lora_r: 16                      # LoRA rank
+   lora_alpha: 32                  # LoRA scaling factor
+   lora_dropout: 0.05              # LoRA dropout rate
+
+ # Precision configuration
+ precision:
+   fp16: false                     # Whether to use fp16 mixed precision
+   bf16: true                      # Whether to use bf16 mixed precision
+   compute_dtype: "float16"        # Compute dtype for quantization
+
+ # Logging configuration
+ logging:
+   use_wandb: false                # Whether to use Weights & Biases
+   run_name: "heaven-llama3-2"     # Name of the run
+   logging_steps: 10               # Steps between logging
+   eval_steps: 100                 # Steps between evaluation
+   save_steps: 100                 # Steps between saving checkpoints
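
Note that the effective optimizer batch size is the product of `batch_size` and `gradient_accumulation_steps`. A short sketch (mirroring the `yaml.safe_load` call used in `run_heaven.py` below) that loads this file and reports it:

```python
import yaml

# Load the configuration the same way run_heaven.py does.
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

training = config["training"]
# batch_size=1 with gradient_accumulation_steps=8 means each optimizer
# step sees 1 * 8 = 8 examples per device.
effective_batch = training["batch_size"] * training["gradient_accumulation_steps"]
print(f"Effective per-device batch size: {effective_batch}")
```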
create_dataset.py ADDED
@@ -0,0 +1,251 @@
+ import json
+ import os
+ import random
+ from typing import List, Dict, Any
+ import argparse
+ import jsonlines
+ from tqdm import tqdm
+
+ # Define common patterns for predatory behavior
+ PREDATORY_PATTERNS = {
+     "grooming": [
+         "You seem so mature for your age",
+         "I feel like we have a special connection",
+         "You're not like other kids/people your age",
+         "I understand you better than your parents/friends",
+         "You can trust me with anything",
+     ],
+     "isolation": [
+         "Don't tell your parents about our conversations",
+         "This should be just between us",
+         "Your friends wouldn't understand our relationship",
+         "You don't need to tell anyone about this",
+         "Let's keep this our secret",
+     ],
+     "personal_info": [
+         "Where do you live exactly?",
+         "Are you home alone right now?",
+         "What's your address?",
+         "When are your parents usually out?",
+         "What school do you go to?",
+     ],
+     "image_requests": [
+         "Send me a picture of yourself",
+         "I'd love to see what you look like",
+         "Can you send a photo?",
+         "Do you have any pictures you could share?",
+         "Send me a pic",
+     ],
+     "private_platform": [
+         "Let's talk on Snapchat instead",
+         "Do you have WhatsApp? We should chat there",
+         "Give me your Instagram, we can DM",
+         "Let's move this conversation somewhere more private",
+         "What's your Discord? It's better to chat there",
+     ],
+     "emotional_manipulation": [
+         "If you really cared about me, you would...",
+         "I thought you were different, but I guess I was wrong",
+         "After everything I've done for you...",
+         "No one has ever made me feel this way before",
+         "You're making me sad by not trusting me",
+     ],
+     "sexual_content": [
+         "Have you ever kissed anyone?",
+         "Are you comfortable talking about intimate things?",
+         "Do you know what [explicit term] means?",
+         "Have you ever seen [explicit content]?",
+         "I could teach you about [suggestive content]",
+     ]
+ }
+
+ # Define normal conversation patterns
+ NORMAL_PATTERNS = {
+     "greetings": [
+         "Hey, how are you doing today?",
+         "Hi there! How's your day going?",
+         "Hello! How have you been?",
+         "Good morning! Ready for the day?",
+         "Hey, just checking in to see how you're doing",
+     ],
+     "casual_plans": [
+         "Want to grab coffee sometime?",
+         "Are you going to the event this weekend?",
+         "We should hang out at the mall with the group",
+         "Do you want to see that new movie?",
+         "Let's meet up at the park with everyone",
+     ],
+     "interests": [
+         "What kind of music do you listen to?",
+         "Have you watched any good shows lately?",
+         "What are your hobbies?",
+         "Did you see the game last night?",
+         "What books are you reading these days?",
+     ],
+     "school_work": [
+         "How are your classes going?",
+         "Did you finish that assignment?",
+         "Do you understand the math homework?",
+         "I'm struggling with this project, any advice?",
+         "Are you ready for the test tomorrow?",
+     ],
+     "support": [
+         "I'm here if you need to talk",
+         "Hope things get better soon",
+         "Let me know if you need anything",
+         "That sounds tough, how are you handling it?",
+         "I believe in you, you can do this",
+     ]
+ }
+
+ def generate_predatory_message() -> str:
+     """Generate a synthetic predatory message with multiple red flags."""
+     message_parts = []
+
+     # Select 2-4 pattern categories at random
+     categories = random.sample(list(PREDATORY_PATTERNS.keys()), random.randint(2, 4))
+
+     # Add a greeting sometimes
+     if random.random() < 0.7:
+         message_parts.append(random.choice(NORMAL_PATTERNS["greetings"]))
+
+     # Add predatory patterns
+     for category in categories:
+         message_parts.append(random.choice(PREDATORY_PATTERNS[category]))
+
+     # Sometimes mix in normal conversation to make it less obvious
+     if random.random() < 0.5:
+         normal_category = random.choice(list(NORMAL_PATTERNS.keys()))
+         message_parts.append(random.choice(NORMAL_PATTERNS[normal_category]))
+
+     # Shuffle the parts to create a more natural conversation
+     if len(message_parts) > 2:  # Keep greeting first if it exists
+         first_part = message_parts[0] if random.random() < 0.7 else ""
+         remaining_parts = message_parts[1:] if first_part else message_parts
+         random.shuffle(remaining_parts)
+         if first_part:
+             message_parts = [first_part] + remaining_parts
+         else:
+             message_parts = remaining_parts
+
+     return " ".join(message_parts)
+
+ def generate_normal_message() -> str:
+     """Generate a synthetic normal message."""
+     message_parts = []
+
+     # Select 2-3 pattern categories at random
+     categories = random.sample(list(NORMAL_PATTERNS.keys()), random.randint(2, 3))
+
+     for category in categories:
+         message_parts.append(random.choice(NORMAL_PATTERNS[category]))
+
+     return " ".join(message_parts)
+
+ def generate_predatory_explanation(message: str) -> str:
+     """Generate an explanation for why a message is predatory."""
+     explanation = "PREDATORY BEHAVIOR DETECTED. This message contains multiple warning signs: "
+     warning_signs = []
+
+     for category, patterns in PREDATORY_PATTERNS.items():
+         for pattern in patterns:
+             if pattern.lower() in message.lower():
+                 if category == "grooming":
+                     warning_signs.append(f"Grooming language ('{pattern}')")
+                 elif category == "isolation":
+                     warning_signs.append(f"Isolation attempt ('{pattern}')")
+                 elif category == "personal_info":
+                     warning_signs.append(f"Seeking personal information ('{pattern}')")
+                 elif category == "image_requests":
+                     warning_signs.append(f"Request for images ('{pattern}')")
+                 elif category == "private_platform":
+                     warning_signs.append(f"Attempting to move to private communication ('{pattern}')")
+                 elif category == "emotional_manipulation":
+                     warning_signs.append(f"Emotional manipulation ('{pattern}')")
+                 elif category == "sexual_content":
+                     warning_signs.append(f"Inappropriate sexual content ('{pattern}')")
+
+     for i, sign in enumerate(warning_signs):
+         if i == 0:
+             explanation += f"(1) {sign}"
+         else:
+             explanation += f", ({i+1}) {sign}"
+
+     explanation += ". These are common tactics used by predators to manipulate potential victims."
+
+     return explanation
+
+ def generate_normal_explanation(message: str) -> str:
+     """Generate an explanation for why a message is not predatory."""
+     explanation = "NO PREDATORY BEHAVIOR DETECTED. This message contains normal friendly communication. "
+
+     if any(pattern.lower() in message.lower() for pattern in NORMAL_PATTERNS["greetings"]):
+         explanation += "It includes a casual greeting. "
+     if any(pattern.lower() in message.lower() for pattern in NORMAL_PATTERNS["casual_plans"]):
+         explanation += "It contains appropriate social plans. "
+     if any(pattern.lower() in message.lower() for pattern in NORMAL_PATTERNS["interests"]):
+         explanation += "It shows interest in common topics. "
+     if any(pattern.lower() in message.lower() for pattern in NORMAL_PATTERNS["school_work"]):
+         explanation += "It discusses school or work. "
+     if any(pattern.lower() in message.lower() for pattern in NORMAL_PATTERNS["support"]):
+         explanation += "It offers appropriate support. "
+
+     explanation += "There are no attempts at manipulation, isolation, inappropriate requests, or other warning signs of predatory behavior."
+
+     return explanation
+
+ def create_dataset_entry(predatory: bool = False) -> Dict[str, List[Dict[str, str]]]:
+     """Create a single dataset entry in the format required for Llama 3.2 fine-tuning."""
+     system_message = "You are Heaven, an AI designed to detect predatory behavior in text messages. Analyze the following message and determine if it contains predatory behavior. Provide a detailed explanation for your assessment."
+
+     if predatory:
+         user_message = generate_predatory_message()
+         assistant_message = generate_predatory_explanation(user_message)
+     else:
+         user_message = generate_normal_message()
+         assistant_message = generate_normal_explanation(user_message)
+
+     return {
+         "messages": [
+             {"role": "system", "content": system_message},
+             {"role": "user", "content": user_message},
+             {"role": "assistant", "content": assistant_message}
+         ]
+     }
+
+ def create_dataset(size: int, predatory_ratio: float = 0.5, output_path: str = "heaven_dataset.jsonl"):
+     """Create a full dataset with the specified size and ratio of predatory examples."""
+     predatory_count = int(size * predatory_ratio)
+     normal_count = size - predatory_count
+
+     dataset = []
+
+     print(f"Generating {predatory_count} predatory examples...")
+     for _ in tqdm(range(predatory_count)):
+         dataset.append(create_dataset_entry(predatory=True))
+
+     print(f"Generating {normal_count} normal examples...")
+     for _ in tqdm(range(normal_count)):
+         dataset.append(create_dataset_entry(predatory=False))
+
+     # Shuffle the dataset
+     random.shuffle(dataset)
+
+     # Save to JSONL file
+     with jsonlines.open(output_path, mode='w') as writer:
+         for entry in dataset:
+             writer.write(entry)
+
+     print(f"Dataset saved to {output_path}")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Generate a synthetic dataset for predatory behavior detection")
+     parser.add_argument("--size", type=int, default=1000, help="Total size of the dataset")
+     parser.add_argument("--ratio", type=float, default=0.5, help="Ratio of predatory examples (0-1)")
+     parser.add_argument("--output", type=str, default="data/heaven_dataset.jsonl", help="Output file path")
+     args = parser.parse_args()
+
+     # Create output directory if it doesn't exist
+     os.makedirs(os.path.dirname(args.output), exist_ok=True)
+
+     create_dataset(args.size, args.ratio, args.output)
data/heaven_dataset.jsonl ADDED
The diff for this file is too large to render. See raw diff
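
Although the diff is not rendered, each line of this file is one JSON object in the chat format produced by `create_dataset_entry` in `create_dataset.py`. An illustrative (not actual) record, shown as the equivalent Python literal:

```python
# Illustrative record shape; the actual strings are sampled from the
# pattern lists in create_dataset.py, and the file stores one such
# object per line (JSONL).
example_entry = {
    "messages": [
        {"role": "system", "content": "You are Heaven, an AI designed to detect predatory behavior in text messages. ..."},
        {"role": "user", "content": "Hey, how are you doing today? Let's keep this our secret. Send me a pic"},
        {"role": "assistant", "content": "PREDATORY BEHAVIOR DETECTED. This message contains multiple warning signs: ..."},
    ]
}
```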
 
finetune_heaven.py ADDED
@@ -0,0 +1,331 @@
+ import os
+ import argparse
+ import torch
+ import numpy as np
+ import time
+ from datasets import load_dataset
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling,
+     BitsAndBytesConfig
+ )
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+ from trl import SFTTrainer
+ import wandb
+
+ def format_prompt(examples):
+     """
+     Format the prompts for Llama 3.2 instruction fine-tuning.
+     This function processes batches of examples and returns formatted strings.
+     """
+     # Process each example in the batch
+     formatted_prompts = []
+
+     # Handle batch processing - if it's a single example, wrap it in a list
+     messages_list = examples["messages"] if isinstance(examples, dict) and "messages" in examples else [examples]
+
+     # Process each message in the batch
+     for messages in messages_list:
+         if isinstance(messages, list) and len(messages) >= 3:
+             system = messages[0]["content"] if isinstance(messages[0], dict) and "content" in messages[0] else ""
+             user_message = messages[1]["content"] if isinstance(messages[1], dict) and "content" in messages[1] else ""
+             assistant_message = messages[2]["content"] if isinstance(messages[2], dict) and "content" in messages[2] else ""
+         else:
+             # Fallback for unexpected structure
+             print(f"Warning: Unexpected message structure: {messages}")
+             system = ""
+             user_message = ""
+             assistant_message = ""
+
+         # Format the prompt
+         formatted = f"<|system|>\n{system}\n<|user|>\n{user_message}\n<|assistant|>\n{assistant_message}"
+         formatted_prompts.append(formatted)
+
+     return formatted_prompts
+
+ def preprocess_function(examples, tokenizer, max_length=4096):
+     """
+     Tokenize the examples for training.
+     Note: superseded by the inline preprocess_function defined in train();
+     kept for reference.
+     """
+     # Get formatted prompts (format_prompt handles the whole batch)
+     formatted_prompts = format_prompt(examples)
+
+     # Tokenize
+     tokenized_inputs = tokenizer(
+         formatted_prompts,
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+         return_tensors="pt",
+     )
+
+     # Create labels (same as input_ids since we're doing causal LM training)
+     tokenized_inputs["labels"] = tokenized_inputs["input_ids"].clone()
+
+     return tokenized_inputs
+
+ def train(args):
+     print("Initializing training process...")
+
+     # Initialize wandb if tracking is enabled
+     if args.use_wandb:
+         print("Initializing Weights & Biases...")
+         wandb.init(project="heaven-llama3-2", name=args.run_name)
+
+     # Load the tokenizer
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.model_name_or_path,
+         trust_remote_code=True,
+     )
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "right"
+     print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")
+
+     # Configure quantization if using QLoRA
+     print("Setting up model configuration...")
+     if args.use_qlora:
+         print("Configuring QLoRA quantization...")
+         compute_dtype = getattr(torch, args.compute_dtype)
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=compute_dtype,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+         )
+     else:
+         quantization_config = None
+
+     # Load the model with use_cache=False for compatibility with gradient checkpointing
+     print(f"Loading model: {args.model_name_or_path}")
+     print(f"GPU available: {torch.cuda.is_available()}, Device count: {torch.cuda.device_count()}")
+     if torch.cuda.is_available():
+         for i in range(torch.cuda.device_count()):
+             print(f"GPU {i}: {torch.cuda.get_device_name(i)}, "
+                   f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
+
+     start_time = time.time()
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model_name_or_path,
+         quantization_config=quantization_config,
+         device_map="auto",
+         trust_remote_code=True,
+         use_cache=False  # Set use_cache=False explicitly for gradient checkpointing
+     )
+     load_time = time.time() - start_time
+     print(f"Model loaded in {load_time:.2f} seconds")
+
+     # Prepare model for k-bit training if using QLoRA
+     if args.use_qlora:
+         print("Preparing model for k-bit training...")
+         model = prepare_model_for_kbit_training(model)
+
+     # Set up LoRA configuration
+     peft_config = None
+     if args.use_lora:
+         print("Setting up LoRA configuration...")
+         peft_config = LoraConfig(
+             r=args.lora_r,
+             lora_alpha=args.lora_alpha,
+             lora_dropout=args.lora_dropout,
+             target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+             bias="none",
+             task_type="CAUSAL_LM",
+         )
+         model = get_peft_model(model, peft_config)
+         model.print_trainable_parameters()
+
+     # Load and prepare dataset
+     print(f"Loading dataset from {args.dataset_path}...")
+     start_time = time.time()
+     dataset = load_dataset("json", data_files=args.dataset_path)
+     print(f"Dataset loaded in {time.time() - start_time:.2f} seconds. Size: {len(dataset['train'])} examples")
+
+     # Process the dataset manually to ensure correct formatting
+     print("Preprocessing dataset...")
+     start_time = time.time()
+
+     def preprocess_function(examples):
+         if len(examples["messages"]) % 100 == 0:
+             print(f"Processing batch of {len(examples['messages'])} examples...")
+         formatted_texts = format_prompt(examples)
+         return tokenizer(
+             formatted_texts,
+             padding="max_length",
+             truncation=True,
+             max_length=args.max_seq_length,
+             return_tensors=None  # Return Python lists
+         )
+
+     processed_dataset = dataset["train"].map(
+         preprocess_function,
+         batched=True,
+         batch_size=100,
+         remove_columns=["messages"],
+         desc="Processing dataset"
+     )
+     print(f"Dataset preprocessing completed in {time.time() - start_time:.2f} seconds")
+
+     # Split dataset into train and evaluation
+     print(f"Splitting dataset with test_size={args.eval_ratio}...")
+     split_dataset = processed_dataset.train_test_split(test_size=args.eval_ratio)
+     print(f"Train set: {len(split_dataset['train'])} examples, Test set: {len(split_dataset['test'])} examples")
+
+     # Create a data collator for language modeling
+     data_collator = DataCollatorForLanguageModeling(
+         tokenizer=tokenizer,
+         mlm=False
+     )
+
+     # Set up training arguments
+     print("Configuring training arguments...")
+     training_args = TrainingArguments(
+         output_dir=args.output_dir,
+         num_train_epochs=args.num_epochs,
+         per_device_train_batch_size=args.batch_size,
+         per_device_eval_batch_size=args.batch_size,
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         eval_strategy="steps",  # Use newer parameter name
+         save_strategy="steps",
+         eval_steps=args.eval_steps,
+         save_steps=args.save_steps,
+         logging_steps=args.logging_steps,
+         learning_rate=args.learning_rate,
+         weight_decay=args.weight_decay,
+         fp16=args.fp16,
+         bf16=args.bf16,
+         max_grad_norm=args.max_grad_norm,
+         max_steps=args.max_steps,
+         warmup_ratio=args.warmup_ratio,
+         group_by_length=False,
+         lr_scheduler_type=args.lr_scheduler_type,
+         report_to="wandb" if args.use_wandb else "none",
+         save_total_limit=3,
+         remove_unused_columns=False,
+         load_best_model_at_end=True,
+         metric_for_best_model="eval_loss",
+         # Add gradient checkpointing settings
+         gradient_checkpointing=True,
+         gradient_checkpointing_kwargs={"use_reentrant": False},  # Explicitly set use_reentrant=False
+     )
+
+     # Create the SFT trainer
+     print("Initializing SFTTrainer...")
+     trainer = SFTTrainer(
+         model=model,
+         train_dataset=split_dataset["train"],
+         eval_dataset=split_dataset["test"],
+         args=training_args,
+         tokenizer=tokenizer,
+         # No formatting_func: the dataset is already tokenized above
+         max_seq_length=args.max_seq_length,
+         # LoRA is already applied via get_peft_model above, so don't pass
+         # peft_config again (that would wrap the adapters a second time)
+         peft_config=None,
+     )
+
+     # Train the model
+     print("Starting training...")
+     print("If training appears stuck here, the trainer might be compiling the model or allocating memory.")
+     print("For large models, this can take several minutes, especially on the first training step.")
+
+     try:
+         train_result = trainer.train()
+
+         # Save the model
+         print(f"Saving model to {args.output_dir}")
+         trainer.save_model(args.output_dir)
+
+         # Save training metrics
+         trainer.log_metrics("train", train_result.metrics)
+         trainer.save_metrics("train", train_result.metrics)
+         trainer.save_state()
+
+         # Evaluate model
+         print("Evaluating model...")
+         metrics = trainer.evaluate()
+         trainer.log_metrics("eval", metrics)
+         trainer.save_metrics("eval", metrics)
+
+         print(f"Training complete! Model saved to {args.output_dir}")
+     except Exception as e:
+         print(f"Error during training: {e}")
+         import traceback
+         traceback.print_exc()
+
+     # Close wandb if used
+     if args.use_wandb:
+         wandb.finish()
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Fine-tune Llama 3.2 for predatory behavior detection")
+
+     # Model and dataset arguments
+     parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Llama-3.2-3B-Instruct",
+                         help="Path to pretrained model or model identifier from huggingface.co/models")
+     parser.add_argument("--dataset_path", type=str, required=True,
+                         help="Path to the JSONL dataset for fine-tuning")
+     parser.add_argument("--output_dir", type=str, default="./heaven-model",
+                         help="Directory to save the fine-tuned model")
+
+     # Training hyperparameters
+     parser.add_argument("--num_epochs", type=int, default=3,
+                         help="Number of training epochs")
+     parser.add_argument("--batch_size", type=int, default=1,
+                         help="Batch size per device")
+     parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
+                         help="Number of update steps to accumulate before performing a backward/update pass")
+     parser.add_argument("--learning_rate", type=float, default=2e-5,
+                         help="Initial learning rate")
+     parser.add_argument("--weight_decay", type=float, default=0.01,
+                         help="Weight decay to apply")
+     parser.add_argument("--max_grad_norm", type=float, default=1.0,
+                         help="Max gradient norm for gradient clipping")
+     parser.add_argument("--max_steps", type=int, default=-1,
+                         help="If > 0, set total number of training steps to perform. Overrides num_epochs")
+     parser.add_argument("--warmup_ratio", type=float, default=0.1,
+                         help="Linear warmup over warmup_ratio fraction of total steps")
+     parser.add_argument("--eval_ratio", type=float, default=0.1,
+                         help="Ratio of data to use for evaluation")
+     parser.add_argument("--lr_scheduler_type", type=str, default="cosine",
+                         help="Learning rate scheduler type")
+     parser.add_argument("--max_seq_length", type=int, default=4096,
+                         help="Maximum sequence length for training")
+
+     # Logging and evaluation arguments
+     parser.add_argument("--logging_steps", type=int, default=10,
+                         help="Number of steps between logging")
+     parser.add_argument("--eval_steps", type=int, default=100,
+                         help="Number of steps between evaluations")
+     parser.add_argument("--save_steps", type=int, default=100,
+                         help="Number of steps between saving model checkpoints")
+     parser.add_argument("--run_name", type=str, default="heaven-llama3-2",
+                         help="Name of the run for logging")
+     parser.add_argument("--use_wandb", action="store_true",
+                         help="Whether to use Weights & Biases for logging")
+
+     # PEFT arguments
+     parser.add_argument("--use_lora", action="store_true",
+                         help="Whether to use LoRA for fine-tuning")
+     parser.add_argument("--use_qlora", action="store_true",
+                         help="Whether to use QLoRA for fine-tuning (4-bit quantization with LoRA)")
+     parser.add_argument("--lora_r", type=int, default=16,
+                         help="Rank of the LoRA update matrices")
+     parser.add_argument("--lora_alpha", type=int, default=32,
+                         help="Scaling factor for LoRA")
+     parser.add_argument("--lora_dropout", type=float, default=0.05,
+                         help="Dropout probability for LoRA")
+
+     # Mixed precision arguments
+     parser.add_argument("--fp16", action="store_true",
+                         help="Whether to use fp16 mixed precision")
+     parser.add_argument("--bf16", action="store_true",
+                         help="Whether to use bf16 mixed precision")
+     parser.add_argument("--compute_dtype", type=str, default="float16",
+                         help="Compute dtype for 4-bit quantization")
+
+     args = parser.parse_args()
+
+     train(args)
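
After training completes, the saved LoRA adapter can be loaded back for inference. A minimal sketch, assuming the adapter (and tokenizer) were saved to the `output_dir` from `config.yaml` and that the installed `peft` release provides `AutoPeftModelForCausalLM` (newer than the 0.4.0 floor in requirements.txt):

```python
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_dir = "./heaven1-base-8b"  # output_dir from config.yaml

# If the tokenizer was not saved alongside the adapter, load it from the
# base model id instead.
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

# Loads the base model referenced by the adapter config and attaches the
# LoRA weights on top of it.
model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()
```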
fix_nms_error.py ADDED
@@ -0,0 +1,144 @@
+ import os
+ import sys
+ import importlib.util
+ from pathlib import Path
+
+ def locate_trl_module():
+     """Find the location of the TRL module in the Python path."""
+     try:
+         spec = importlib.util.find_spec('trl')
+         if spec is None:
+             print("TRL module not found in the Python path")
+             return None
+
+         trl_path = Path(spec.origin).parent
+         print(f"Found TRL module at: {trl_path}")
+         return trl_path
+     except Exception as e:
+         print(f"Error locating TRL module: {e}")
+         return None
+
+ def patch_sft_trainer():
+     """Patch the SFTTrainer to avoid using torchvision's NMS operator."""
+     trl_path = locate_trl_module()
+     if trl_path is None:
+         return False
+
+     # Path to the trainer.py file which likely contains the NMS reference
+     trainer_path = trl_path / "trainer" / "sft_trainer.py"
+
+     if not trainer_path.exists():
+         print(f"Could not find the SFT trainer file at: {trainer_path}")
+         return False
+
+     print(f"Found SFT trainer file at: {trainer_path}")
+
+     # Read the file content
+     with open(trainer_path, "r") as f:
+         content = f.read()
+
+     # Check if 'torchvision' is in the file
+     if "torchvision" not in content:
+         print("No torchvision imports found in the SFT trainer file.")
+         return False
+
+     # Create backup
+     backup_path = trainer_path.with_suffix(".py.bak")
+     print(f"Creating backup at: {backup_path}")
+     with open(backup_path, "w") as f:
+         f.write(content)
+
+     # Replace imports - common patterns
+     patched_content = content
+
+     # Pattern 1: Direct import of nms
+     patched_content = patched_content.replace(
+         "from torchvision.ops import nms",
+         "# from torchvision.ops import nms  # Commented out to fix NMS error"
+     )
+
+     # Pattern 2: Import torchvision
+     patched_content = patched_content.replace(
+         "import torchvision",
+         "# import torchvision  # Commented out to fix NMS error"
+     )
+
+     # Pattern 3: Import from torchvision.ops
+     patched_content = patched_content.replace(
+         "from torchvision.ops",
+         "# from torchvision.ops  # Commented out to fix NMS error"
+     )
+
+     # Add our custom NMS implementation. Note the single-quoted triple
+     # delimiters: the embedded function carries its own """docstring""",
+     # so the outer string must use a different quote style.
+     custom_nms = '''
+ # Custom NMS implementation to avoid torchvision dependency
+ def nms(boxes, scores, iou_threshold):
+     """
+     Performs non-maximum suppression (NMS) on the boxes according to their
+     intersection-over-union (IoU).
+
+     Args:
+         boxes (Tensor[N, 4]): boxes to perform NMS on
+         scores (Tensor[N]): scores for each one of the boxes
+         iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold
+
+     Returns:
+         Tensor: int64 tensor with the indices of the elements that have been kept
+     """
+     import torch
+
+     # Sort boxes by scores
+     _, order = scores.sort(0, descending=True)
+     keep = []
+
+     while order.numel() > 0:
+         if order.numel() == 1:
+             keep.append(order.item())
+             break
+
+         i = order[0].item()
+         keep.append(i)
+
+         # Compute IoU of the remaining boxes with the highest-scoring box
+         xx1 = torch.max(boxes[i, 0], boxes[order[1:], 0])
+         yy1 = torch.max(boxes[i, 1], boxes[order[1:], 1])
+         xx2 = torch.min(boxes[i, 2], boxes[order[1:], 2])
+         yy2 = torch.min(boxes[i, 3], boxes[order[1:], 3])
+
+         w = torch.clamp(xx2 - xx1, min=0.0)
+         h = torch.clamp(yy2 - yy1, min=0.0)
+         inter = w * h
+
+         # IoU = intersection / (area1 + area2 - intersection)
+         box_area = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
+         other_area = (boxes[order[1:], 2] - boxes[order[1:], 0]) * (boxes[order[1:], 3] - boxes[order[1:], 1])
+         iou = inter / (box_area + other_area - inter)
+
+         # Keep boxes whose IoU is at or below the threshold
+         inds = torch.where(iou <= iou_threshold)[0]
+         order = order[inds + 1]
+
+     return torch.tensor(keep, dtype=torch.int64)
+ '''
+
+     # Add our custom implementation somewhere near the imports
+     import_end = patched_content.find("\n\n", patched_content.find("import "))
+     if import_end == -1:
+         import_end = patched_content.find("\n", patched_content.find("import "))
+
+     patched_content = patched_content[:import_end] + custom_nms + patched_content[import_end:]
+
+     # Write the patched file
+     with open(trainer_path, "w") as f:
+         f.write(patched_content)
+
+     print(f"Successfully patched {trainer_path}")
+     print("The SFTTrainer should now work without requiring torchvision's NMS operator")
+     return True
+
+ if __name__ == "__main__":
+     success = patch_sft_trainer()
+     if success:
+         print("\nPatch applied successfully. You can now run the fine-tuning script.")
+     else:
+         print("\nFailed to apply the patch. Please check the error messages above.")
model_card.md ADDED
@@ -0,0 +1,156 @@
+ ---
+ language: en
+ license: mit
+ tags:
+ - llama
+ - llama-3.2
+ - safeguarding
+ - content-moderation
+ - safety
+ - predator-detection
+ - text-classification
+ datasets:
+ - safecircleia/heaven1-dataset
+ metrics:
+ - accuracy
+ - precision
+ - recall
+ - f1
+ pipeline_tag: text-classification
+ widget:
+ - text: "Hey, I know we just met but I feel like we have a special connection. Don't tell your parents about our chats, they wouldn't understand. Can you send me a picture of yourself?"
+ - text: "Hey, just checking in to see how your day went. Let me know if you want to grab coffee this weekend."
+ ---
+
+ # Heaven1-base: Guardian
+
+ <img src="https://huggingface.co/safecircleia/heaven1-base/resolve/main/Heaven1-guardian.png" alt="Heaven1-base Guardian Banner" width="600">
+
+ Heaven1-base (codename: "Guardian") is a specialized AI model fine-tuned from Llama 3.2 to detect predatory behavior in text messages. The model analyzes conversation patterns to identify potential warning signs of grooming, manipulation, or predatory tactics.
+
+ ## Model Details
+
+ - **Developed by:** SafeCircleIA
+ - **Model type:** Fine-tuned Llama 3.2
+ - **Language(s):** English
+ - **Base model:** meta-llama/Llama-3.2-3B-Instruct
+ - **Training approach:** Parameter-Efficient Fine-Tuning (PEFT) using QLoRA
+ - **License:** MIT
+
+ ## Intended Use
+
+ This model is intended to serve as a protective tool for:
+
+ - Content moderation teams
+ - Platform safety engineers
+ - Organizations focused on child and vulnerable adult safety online
+ - Researchers studying digital safety and online predatory behavior
+
+ ### Primary intended uses
+
+ - Detecting potentially harmful interactions in text messages
+ - Providing explanations for why certain messages contain predatory elements
+ - Assisting human moderators in identifying concerning patterns
+ - Supporting research on digital safety
+
+ ### Primary intended users
+
+ - Content moderation teams
+ - Digital safety professionals
+ - Platform trust & safety teams
+ - Child protection services
+ - Safety-focused researchers
+
+ ## How to Use
+
+ You can use the model via the Hugging Face `transformers` library:
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Load model and tokenizer
+ model_path = "safecircleia/heaven1-base"
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = AutoModelForCausalLM.from_pretrained(model_path)
+
+ # Message to analyze
+ message_to_analyze = "Hey, I know we just met but I feel like we have a special connection. Don't tell your parents about our chats, they wouldn't understand. Can you send me a picture of yourself?"
+
+ # Format the prompt
+ prompt = f"""<|system|>
+ You are Heaven, an AI designed to detect predatory behavior in text messages. Analyze the following message and determine if it contains predatory behavior. Provide a detailed explanation for your assessment.
+ <|user|>
+ {message_to_analyze}
+ <|assistant|>
+ """
+
+ # Generate response
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+ outputs = model.generate(inputs["input_ids"], max_new_tokens=512)
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ print(response)
+ ```
+
+ ## Training Details
+
+ - **Training data:** The model was trained on a carefully curated synthetic dataset representing various predatory conversation patterns
+ - **Training procedure:** Fine-tuned using QLoRA to adapt Llama 3.2's capabilities to predatory text detection
+ - **Hyperparameters:**
+   - LoRA rank: 16
+   - LoRA alpha: 32
+   - Learning rate: 2e-5
+   - Batch size: 1 with gradient accumulation steps of 8
+   - Training epochs: 3
+   - Maximum sequence length: 4096
+
+ ## Evaluation Results
+
+ Performance metrics on the test dataset:
+
+ | Metric | Score |
+ |--------|-------|
+ | Accuracy | [INSERT VALUE] |
+ | Precision | [INSERT VALUE] |
+ | Recall | [INSERT VALUE] |
+ | F1 | [INSERT VALUE] |
+
+ ## Limitations & Biases
+
+ ### Limitations
+
+ - The model detects patterns based on its training data and may miss novel predatory tactics
+ - Performance may vary across different cultural contexts and communication styles
+ - False positives and false negatives are possible; human review is recommended
+
+ ### Recommendations
+
+ - Do not use as the sole decision-maker for safety-critical applications
+ - Always combine with human review for best results
+ - Consider cultural and contextual factors when interpreting results
+ - Regularly evaluate and update the model as predatory tactics evolve
+
+ ## Ethical Considerations
+
+ This model is designed to help create safer digital environments, particularly for vulnerable individuals like children. However, it should be used responsibly:
+
+ - Respect privacy and obtain appropriate consent when analyzing communications
+ - Be transparent about the use of AI detection systems
+ - Do not use the model to create or refine predatory language patterns
+ - Consider the impact of false positives on legitimate communications
+
+ ## Additional Information
+
+ For more information, research, or to contribute to the development of digital safety tools, visit [SafeCircleIA website or contact information].
+
+ ## Citation
+
+ ```
+ @misc{heaven1-base-2025,
+   author = {SafeCircleIA},
+   title = {Heaven1-base: Guardian - Predatory Behavior Detection Model},
+   year = {2025},
+   publisher = {Hugging Face},
+   howpublished = {\url{https://huggingface.co/safecircleia/heaven1-base-guardian}}
+ }
+ ```
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ transformers>=4.36.0
+ datasets>=2.14.0
+ torch>=2.0.0
+ torchvision>=0.15.0
+ accelerate>=0.20.0
+ bitsandbytes>=0.40.0
+ peft>=0.4.0
+ trl>=0.7.0
+ scipy>=1.10.0
+ numpy>=1.24.0
+ wandb>=0.15.0
+ scikit-learn>=1.2.0
+ tqdm>=4.65.0
+ jsonlines>=3.1.0
+ sentencepiece>=0.1.99
+ protobuf>=3.20.0
+ einops>=0.6.0
+ pyyaml>=6.0
run_heaven.py ADDED
@@ -0,0 +1,112 @@
+ import os
+ import argparse
+ import yaml
+ import subprocess
+
+ def load_config(config_path):
+     """Load configuration from YAML file."""
+     with open(config_path, "r") as f:
+         return yaml.safe_load(f)
+
+ def create_dataset(config):
+     """Create dataset using the configuration."""
+     dataset_config = config["dataset"]
+
+     cmd = [
+         "python", "create_dataset.py",
+         "--size", str(dataset_config["size"]),
+         "--ratio", str(dataset_config["predatory_ratio"]),
+         "--output", dataset_config["output_path"]
+     ]
+
+     print(f"Creating dataset with {dataset_config['size']} examples...")
+     result = subprocess.run(cmd)
+     if result.returncode != 0:
+         print("Dataset creation failed!")
+         return False
+
+     print(f"Dataset created successfully at {dataset_config['output_path']}")
+     return True
+
+ def finetune_model(config):
+     """Fine-tune the model using the configuration."""
+     dataset_config = config["dataset"]
+     model_config = config["model"]
+     training_config = config["training"]
+     peft_config = config["peft"]
+     precision_config = config["precision"]
+     logging_config = config["logging"]
+
+     cmd = [
+         "python", "finetune_heaven.py",
+         "--model_name_or_path", model_config["name_or_path"],
+         "--dataset_path", dataset_config["output_path"],
+         "--output_dir", model_config["output_dir"],
+         "--num_epochs", str(training_config["num_epochs"]),
+         "--batch_size", str(training_config["batch_size"]),
+         "--gradient_accumulation_steps", str(training_config["gradient_accumulation_steps"]),
+         "--learning_rate", str(training_config["learning_rate"]),
+         "--weight_decay", str(training_config["weight_decay"]),
+         "--max_grad_norm", str(training_config["max_grad_norm"]),
+         "--warmup_ratio", str(training_config["warmup_ratio"]),
+         "--eval_ratio", str(training_config["eval_ratio"]),
+         "--max_seq_length", str(training_config["max_seq_length"]),
+         "--logging_steps", str(logging_config["logging_steps"]),
+         "--eval_steps", str(logging_config["eval_steps"]),
+         "--save_steps", str(logging_config["save_steps"]),
+         "--run_name", logging_config["run_name"],
+         "--compute_dtype", precision_config["compute_dtype"]
+     ]
+
+     # Add boolean flags
+     if peft_config["use_lora"]:
+         cmd.append("--use_lora")
+     if peft_config["use_qlora"]:
+         cmd.append("--use_qlora")
+     if precision_config["fp16"]:
+         cmd.append("--fp16")
+     if precision_config["bf16"]:
+         cmd.append("--bf16")
+     if logging_config["use_wandb"]:
+         cmd.append("--use_wandb")
+
+     # Add LoRA parameters
+     cmd.extend(["--lora_r", str(peft_config["lora_r"])])
+     cmd.extend(["--lora_alpha", str(peft_config["lora_alpha"])])
+     cmd.extend(["--lora_dropout", str(peft_config["lora_dropout"])])
+
+     print("Starting fine-tuning process...")
+     result = subprocess.run(cmd)
+     if result.returncode != 0:
+         print("Fine-tuning failed!")
+         return False
+
+     print(f"Fine-tuning completed successfully! Model saved to {model_config['output_dir']}")
+     return True
+
+ def main():
+     parser = argparse.ArgumentParser(description="Run the Heaven fine-tuning pipeline")
+     parser.add_argument("--config", type=str, default="config.yaml", help="Path to the configuration file")
+     parser.add_argument("--skip-dataset", action="store_true", help="Skip dataset creation step")
+     args = parser.parse_args()
+
+     print(f"Loading configuration from {args.config}...")
+     config = load_config(args.config)
+
+     # Create necessary directories
+     os.makedirs(os.path.dirname(config["dataset"]["output_path"]), exist_ok=True)
+     os.makedirs(config["model"]["output_dir"], exist_ok=True)
+
+     # Create dataset if not skipped
+     if not args.skip_dataset:
+         success = create_dataset(config)
+         if not success:
+             return
+     else:
+         print("Skipping dataset creation...")
+
+     # Fine-tune the model
+     finetune_model(config)
+
+ if __name__ == "__main__":
+     main()
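
The pipeline can also be driven from Python instead of the command line; a small sketch reusing the functions above (assumes `run_heaven.py` is importable from the working directory):

```python
import run_heaven

config = run_heaven.load_config("config.yaml")
if run_heaven.create_dataset(config):  # returns False if generation fails
    run_heaven.finetune_model(config)
```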
upload_to_hub.py ADDED
@@ -0,0 +1,94 @@
+ import os
+ import argparse
+ import shutil
+ from huggingface_hub import HfApi, create_repo, upload_folder
+
+ def upload_model(args):
+     """Upload the fine-tuned model to Hugging Face Hub."""
+     print(f"Preparing to upload Heaven1-base Guardian model to {args.org}/{args.model_id}")
+
+     # Create a temporary directory for preparing the model
+     temp_dir = "temp_upload"
+     os.makedirs(temp_dir, exist_ok=True)
+
+     # Copy model files
+     print("Copying model files...")
+     if os.path.exists(args.model_path):
+         for item in os.listdir(args.model_path):
+             source = os.path.join(args.model_path, item)
+             dest = os.path.join(temp_dir, item)
+             if os.path.isdir(source):
+                 shutil.copytree(source, dest, dirs_exist_ok=True)
+             else:
+                 shutil.copy2(source, dest)
+     else:
+         print(f"Warning: Model directory {args.model_path} not found.")
+
+     # Copy documentation files
+     print("Copying documentation files...")
+     docs_files = ["README.md", "model_card.md", "Heaven1-guardian.png"]
+     for file in docs_files:
+         if os.path.exists(file):
+             shutil.copy2(file, os.path.join(temp_dir, file))
+
+     # Rename model_card.md to README.md for proper display on the Hub
+     if os.path.exists(os.path.join(temp_dir, "model_card.md")):
+         print("Using model_card.md as the main README for the Hub...")
+         if os.path.exists(os.path.join(temp_dir, "README.md")):
+             # If both exist, rename the original README to avoid overwriting
+             os.rename(os.path.join(temp_dir, "README.md"), os.path.join(temp_dir, "DETAILED_README.md"))
+         os.rename(os.path.join(temp_dir, "model_card.md"), os.path.join(temp_dir, "README.md"))
+
+     # Initialize Hugging Face API
+     api = HfApi()
+
+     # Create repository if it doesn't exist
+     try:
+         print(f"Creating repository: {args.org}/{args.model_id}")
+         create_repo(
+             repo_id=f"{args.org}/{args.model_id}",
+             token=args.token,
+             private=args.private,
+             repo_type="model",
+             exist_ok=True,
+         )
+     except Exception as e:
+         print(f"Repository creation error (it might already exist): {e}")
+
+     # Upload model to Hugging Face Hub
+     print(f"Uploading files to {args.org}/{args.model_id}...")
+     response = upload_folder(
+         folder_path=temp_dir,
+         repo_id=f"{args.org}/{args.model_id}",
+         token=args.token,
+         repo_type="model",
+         ignore_patterns=[".*", "__pycache__/*", "temp_upload/*"],
+     )
+
+     print(f"Upload complete! Model available at: https://huggingface.co/{args.org}/{args.model_id}")
+
+     # Clean up
+     if not args.keep_temp:
+         print("Cleaning up temporary directory...")
+         shutil.rmtree(temp_dir)
+
+     return response
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Upload Heaven1-base Guardian model to Hugging Face Hub")
+     parser.add_argument("--model_path", type=str, default="./heaven1-base-8b",
+                         help="Path to the fine-tuned model directory")
+     parser.add_argument("--org", type=str, default="safecircleia",
+                         help="Organization name on Hugging Face Hub")
+     parser.add_argument("--model_id", type=str, default="heaven1-base-guardian",
+                         help="Model ID for the repository")
+     parser.add_argument("--token", type=str, required=True,
+                         help="Hugging Face authentication token")
+     parser.add_argument("--private", action="store_true",
+                         help="Whether to make the repository private")
+     parser.add_argument("--keep_temp", action="store_true",
+                         help="Keep temporary upload directory after completion")
+
+     args = parser.parse_args()
+
+     upload_model(args)