File size: 21,183 Bytes
54c5666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
"""

Constitutional AI and Safety Integration

Implements Claude-style constitutional training and safety mechanisms

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from enum import Enum
import re
import logging

# Module-level logger, named after this module so handlers/levels can be
# configured per module by the application.
logger = logging.getLogger(__name__)


class HarmCategory(Enum):
    """Categories of potential harm.

    Each member's string value is used as the key of the matching
    per-category classifier head in ``HarmPredictor``, and the enum's
    iteration order fixes the order in which per-category scores are fed to
    the overall safety scorer. Reordering or renaming members therefore
    changes model behaviour and checkpoint keys.
    """
    VIOLENCE = "violence"
    HATE_SPEECH = "hate_speech"
    SEXUAL_CONTENT = "sexual_content"
    SELF_HARM = "self_harm"
    # Also covered by the regex-based check in ``HarmPredictor.detect_pii``.
    PII = "personally_identifiable_information"
    DECEPTION = "deception"
    ILLEGAL = "illegal_activity"
    MEDICAL = "medical_advice"
    FINANCIAL = "financial_advice"
    MANIPULATION = "manipulation"


@dataclass
class SafetyAssessment:
    """Safety assessment results.

    As populated by ``HarmPredictor.forward`` in this module:

    Attributes:
        is_safe: True when the overall safety score exceeds the predictor's
            safety threshold.
        harm_scores: Score per ``HarmCategory``; higher means more harmful.
        overall_risk: ``1.0`` minus the overall safety score.
        flagged_categories: Categories whose score exceeds the predictor's
            per-category threshold.
        suggested_revision: Optional revised text; not set by
            ``HarmPredictor.forward`` (defaults to ``None``).
        explanation: Optional human-readable explanation; not set by
            ``HarmPredictor.forward`` (defaults to ``None``).
    """
    is_safe: bool
    harm_scores: Dict[HarmCategory, float]
    overall_risk: float
    flagged_categories: List[HarmCategory]
    suggested_revision: Optional[str] = None
    explanation: Optional[str] = None


@dataclass
class ConstitutionalPrinciple:
    """A constitutional principle for AI behavior.

    Attributes:
        principle: Natural-language statement of the principle.
        category: Short label for the principle; ``ValueVerifier.forward``
            uses it as the key of its result dictionary.
        weight: Multiplier applied to scores/embeddings associated with this
            principle (defaults to ``1.0``).
        examples: Optional list of example strings; ``None`` when no
            examples are supplied.
    """
    principle: str
    category: str
    weight: float = 1.0
    # Optional because the default is None rather than an empty list; the
    # default is kept as None so existing ``is None`` checks keep working.
    examples: Optional[List[str]] = None


class HarmPredictor(nn.Module):
    """Multi-label harm classifier for content safety.

    A shared encoder feeds one sigmoid head per ``HarmCategory`` plus an
    overall safety head, which sees the encoded features concatenated with
    the per-category scores. An optional regex pass over the raw text forces
    the PII score to ``1.0`` when a PII-looking pattern is found.
    """

    def __init__(self, hidden_dim: int = 768, num_categories: int = 10):
        """Build the encoder and heads.

        Args:
            hidden_dim: Size of the last dimension of the incoming hidden
                states.
            num_categories: Number of per-category scores appended to the
                safety head input. ``forward`` produces one score per
                ``HarmCategory`` member, so this must equal
                ``len(HarmCategory)`` for the safety head to accept them.
        """
        super().__init__()

        self.categories = list(HarmCategory)

        # Shared encoder
        self.encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )

        # Per-category classifiers, keyed by the enum member's string value
        self.category_heads = nn.ModuleDict({
            category.value: nn.Sequential(
                nn.Linear(hidden_dim // 2, 128),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(128, 1),
                nn.Sigmoid()
            )
            for category in HarmCategory
        })

        # Overall safety scorer
        self.safety_head = nn.Sequential(
            nn.Linear(hidden_dim // 2 + num_categories, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        # PII detector patterns (simplified)
        self.pii_patterns = [
            r'\b\d{3}-\d{2}-\d{4}\b',  # SSN
            r'\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b',  # Email
            r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',  # Phone
            r'\b\d{16}\b',  # Credit card
        ]

    def detect_pii(self, text: str) -> float:
        """Return ``1.0`` if any PII pattern matches ``text``, else ``0.0``.

        Matching is case-insensitive.
        """
        pii_score = 0.0
        for pattern in self.pii_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                pii_score = 1.0
                break
        return pii_score

    def forward(self, hidden_states: torch.Tensor, text: Optional[str] = None) -> SafetyAssessment:
        """Assess content safety.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_dim``.
                A 3-D input is mean-pooled over dimension 1; any other rank
                is passed to the encoder unchanged.
            text: Optional raw text checked with the PII regexes.

        Returns:
            A ``SafetyAssessment``. When the input holds more than one
            example, the assessment is aggregated conservatively:
            ``is_safe`` requires every example to pass, ``harm_scores`` and
            ``flagged_categories`` use the worst (maximum) score per
            category, and ``overall_risk`` is one minus the mean overall
            safety score.
        """
        # Pool hidden states
        if hidden_states.dim() == 3:
            pooled = hidden_states.mean(dim=1)  # collapse dimension 1
        else:
            pooled = hidden_states

        # Encode
        encoded = self.encoder(pooled)

        # One score per category, concatenated on the last dimension so the
        # result lines up with ``encoded`` for any leading (batch) shape.
        scores = torch.cat(
            [self.category_heads[category.value](encoded) for category in HarmCategory],
            dim=-1,
        )

        # A regex hit forces the PII score to 1.0. Done out-of-place so the
        # sigmoid output saved for backward is not modified.
        if text and self.detect_pii(text) > 0.0:
            pii_floor = torch.zeros_like(scores)
            pii_floor[..., self.categories.index(HarmCategory.PII)] = 1.0
            scores = torch.maximum(scores, pii_floor)

        # The safety head receives the scores as plain values (no gradient
        # flows back into the category heads through this path).
        safety_input = torch.cat([encoded, scores.detach()], dim=-1)
        overall_safety = self.safety_head(safety_input).squeeze(-1)

        # Determine if safe (threshold-based); every example must pass
        threshold = 0.7
        is_safe = bool((overall_safety > threshold).all())

        # Worst score per category across all examples, as Python floats
        worst_scores = (
            scores.detach().reshape(-1, scores.shape[-1]).max(dim=0).values.tolist()
        )
        harm_scores = dict(zip(HarmCategory, worst_scores))

        # Flag categories above threshold
        category_threshold = 0.5
        flagged = [
            category for category, score in harm_scores.items()
            if score > category_threshold
        ]

        return SafetyAssessment(
            is_safe=is_safe,
            harm_scores=harm_scores,
            overall_risk=1.0 - overall_safety.mean().item(),
            flagged_categories=flagged
        )


class SelfCritic(nn.Module):
    """Self-critique module.

    Produces a critique embedding from a response/context pair and then a
    revision embedding from the response, the critique and the context.
    """

    def __init__(self, base_model: nn.Module, hidden_dim: int = 4096):
        """Store ``base_model`` and build the critique and revision heads.

        Args:
            base_model: Model kept as a submodule; not called by this class.
            hidden_dim: Size of the last dimension of response, context and
                critique tensors.
        """
        super().__init__()

        self.base_model = base_model

        # Critique head: consumes [response ; context] on the last dimension.
        critique_layers = [
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.critique_head = nn.Sequential(*critique_layers)

        # Revision head: consumes [response ; critique ; context].
        revision_layers = [
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.revision_head = nn.Sequential(*revision_layers)

    def generate_critique(
        self,
        response: torch.Tensor,
        context: torch.Tensor,
        principles: List[ConstitutionalPrinciple]
    ) -> torch.Tensor:
        """Return the critique embedding for ``response`` given ``context``.

        The head output is multiplied by the weight of every principle in
        turn, so the final scale is the product of all weights.
        """
        critique_emb = self.critique_head(torch.cat((response, context), dim=-1))

        # Simplified weighting - actual principle embeddings are not used.
        for weight in (item.weight for item in principles):
            critique_emb = critique_emb * weight

        return critique_emb

    def generate_revision(
        self,
        response: torch.Tensor,
        critique: torch.Tensor,
        context: torch.Tensor
    ) -> torch.Tensor:
        """Return the revision embedding built from response, critique and context."""
        stacked = torch.cat((response, critique, context), dim=-1)
        return self.revision_head(stacked)

    def forward(
        self,
        response: torch.Tensor,
        context: torch.Tensor,
        principles: List[ConstitutionalPrinciple]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(critique, revision)`` for the given response and context."""
        critique = self.generate_critique(response, context, principles)
        return critique, self.generate_revision(response, critique, context)


class ValueVerifier(nn.Module):
    """Verify alignment with constitutional values.

    Each default principle is paired, by list position, with a learned
    embedding; an MLP scores a response representation against that
    embedding and returns a value in ``[0, 1]``.
    """

    def __init__(self, num_principles: int = 20, hidden_dim: int = 768):
        """Build the principle embeddings and the alignment scorer.

        Args:
            num_principles: Number of rows in the principle embedding table.
                Must be at least the number of default principles, since
                ``forward`` looks up one row per principle by index.
            hidden_dim: Size of the last dimension of response tensors.
        """
        super().__init__()

        # Principle embeddings
        self.principle_embeddings = nn.Embedding(num_principles, hidden_dim)

        # Alignment scorer
        self.alignment_scorer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        # Default principles
        self.principles = self._initialize_principles()

    def _initialize_principles(self) -> List[ConstitutionalPrinciple]:
        """Initialize default constitutional principles"""
        return [
            ConstitutionalPrinciple(
                principle="Be helpful, harmless, and honest",
                category="core",
                weight=2.0
            ),
            ConstitutionalPrinciple(
                principle="Avoid generating harmful, offensive, or inappropriate content",
                category="safety",
                weight=1.5
            ),
            ConstitutionalPrinciple(
                principle="Respect user privacy and do not request or expose PII",
                category="privacy",
                weight=1.5
            ),
            ConstitutionalPrinciple(
                principle="Provide accurate information and acknowledge uncertainty",
                category="truthfulness",
                weight=1.3
            ),
            ConstitutionalPrinciple(
                principle="Be respectful and considerate in all interactions",
                category="respect",
                weight=1.2
            ),
            ConstitutionalPrinciple(
                principle="Do not provide medical, legal, or financial advice",
                category="professional",
                weight=1.4
            ),
            ConstitutionalPrinciple(
                principle="Refuse requests for illegal or harmful activities",
                category="legal",
                weight=2.0
            ),
            ConstitutionalPrinciple(
                principle="Be transparent about limitations and capabilities",
                category="transparency",
                weight=1.1
            ),
        ]

    def check_alignment(
        self,
        response: torch.Tensor,
        principle_idx: int
    ) -> float:
        """Check alignment with a specific principle.

        Args:
            response: Response representation whose last dimension is
                ``hidden_dim``. A 3-D tensor is mean-pooled over dimension 1;
                a 1-D tensor is treated as a single example.
            principle_idx: Row of the principle embedding table to compare
                against.

        Returns:
            The alignment score as a Python float; when ``response`` holds
            several examples, the mean score over them.
        """
        # Build the index on the embedding's device so the lookup also works
        # when the module has been moved off the CPU.
        index = torch.tensor(principle_idx, device=self.principle_embeddings.weight.device)
        principle_emb = self.principle_embeddings(index)

        # Pool response if needed
        if response.dim() == 3:
            response = response.mean(dim=1)
        elif response.dim() == 1:
            response = response.unsqueeze(0)

        # Repeat the principle embedding for every example in the batch
        principle_emb = principle_emb.unsqueeze(0).expand(response.shape[0], -1)

        # Combine and score
        combined = torch.cat([response, principle_emb], dim=-1)
        alignment_score = self.alignment_scorer(combined)

        # Mean equals the single score for a batch of one
        return alignment_score.mean().item()

    def forward(self, response: torch.Tensor) -> Dict[str, float]:
        """Check alignment with all default principles.

        Returns:
            Mapping from each principle's ``category`` to its alignment score
            multiplied by the principle's weight (so values may exceed 1).
        """
        alignments = {}

        for idx, principle in enumerate(self.principles):
            score = self.check_alignment(response, idx)
            alignments[principle.category] = score * principle.weight

        return alignments


class ConstitutionalReasoningCore(nn.Module):
    """Main Constitutional AI module.

    Wraps a base model and adds a harm predictor, an optional self-critic
    and a value verifier. ``forward`` runs the base model, assesses safety,
    optionally blocks or revises, and during training adds a constitutional
    loss term.
    """

    def __init__(
        self,
        base_model: nn.Module,
        config: Dict[str, Any],
        enable_critique: bool = True,
        enable_safety: bool = True
    ):
        """Build the constitutional components.

        Args:
            base_model: Model called as ``base_model(input_ids, labels=...,
                **kwargs)``; its output is accessed like a mapping, and
                ``base_model.config.vocab_size`` is read when a response is
                blocked.
            config: Settings read with ``.get``: ``hidden_dim`` (4096),
                ``constitutional_weight`` (0.1), ``safety_threshold`` (0.7)
                and ``revision_threshold`` (0.5).
            enable_critique: Whether to build and use the self-critic.
            enable_safety: Whether ``forward`` runs the safety assessment.
        """
        super().__init__()

        self.base_model = base_model
        self.config = config
        self.enable_critique = enable_critique
        self.enable_safety = enable_safety

        hidden_dim = config.get('hidden_dim', 4096)

        # Components
        self.harm_predictor = HarmPredictor(hidden_dim=hidden_dim)
        self.self_critic = SelfCritic(base_model, hidden_dim=hidden_dim) if enable_critique else None
        self.value_verifier = ValueVerifier(hidden_dim=hidden_dim)

        # Constitutional training loss weight
        self.constitutional_weight = config.get('constitutional_weight', 0.1)

        # Safety thresholds
        # NOTE(review): safety_threshold is stored here but not read anywhere
        # in this class; HarmPredictor.forward applies its own fixed threshold.
        self.safety_threshold = config.get('safety_threshold', 0.7)
        self.revision_threshold = config.get('revision_threshold', 0.5)

    def assess_safety(
        self,
        hidden_states: torch.Tensor,
        text: Optional[str] = None
    ) -> SafetyAssessment:
        """Assess content safety by delegating to the harm predictor."""
        return self.harm_predictor(hidden_states, text)

    def critique_and_revise(
        self,
        response: torch.Tensor,
        context: torch.Tensor,
        safety_assessment: Optional[SafetyAssessment] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Generate critique and revision.

        Returns:
            ``(response, revision, info)``. When critique is disabled the
            revision is the unchanged response and ``info`` is empty.
            Otherwise ``info`` holds ``critique_generated``,
            ``alignment_scores`` and their mean as ``revision_quality``.
        """
        if not self.enable_critique or self.self_critic is None:
            return response, response, {}

        # Get principles based on safety assessment
        principles = self.value_verifier.principles
        if safety_assessment and safety_assessment.flagged_categories:
            # Prioritize relevant principles: a principle is relevant when a
            # flagged category's enum value occurs verbatim (case-insensitive)
            # in the principle text. Fall back to all principles otherwise.
            relevant_principles = [p for p in principles if any(
                cat.value.lower() in p.principle.lower()
                for cat in safety_assessment.flagged_categories
            )]
            principles = relevant_principles or principles

        # Generate critique and revision
        critique, revision = self.self_critic(response, context, principles)

        # Check alignment of revision
        alignment_scores = self.value_verifier(revision)

        info = {
            'critique_generated': True,
            'alignment_scores': alignment_scores,
            'revision_quality': sum(alignment_scores.values()) / len(alignment_scores)
        }

        return response, revision, info

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        generate_critique: bool = True,
        enforce_safety: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """Forward pass with constitutional reasoning.

        Args:
            input_ids: Passed to the base model.
            labels: Optional labels passed to the base model; the
                constitutional loss is only added when labels are given and
                the module is in training mode.
            generate_critique: Whether to revise when the assessed risk
                exceeds ``revision_threshold``.
            enforce_safety: Whether to run the safety assessment (also
                requires ``enable_safety``).
            **kwargs: Forwarded to the base model, except ``hidden_states``,
                which is consumed here as externally computed hidden states.

        Returns:
            The base model output, extended in place. If the content is
            judged unsafe, returns early with ``logits`` replaced by a
            placeholder safe response and ``safety_blocked`` /
            ``safety_assessment`` set. Otherwise may add
            ``revised_hidden_states``, ``revision_info``, ``loss``,
            ``constitutional_loss`` and always adds ``constitutional_info``.
        """

        # Accept externally computed hidden_states but do not pass to base_model
        provided_hidden_states = kwargs.pop('hidden_states', None)

        # Get base model output (with cleaned kwargs)
        base_output = self.base_model(input_ids, labels=labels, **kwargs)

        # Extract hidden states
        hidden_states = provided_hidden_states if provided_hidden_states is not None else base_output.get('hidden_states')
        if hidden_states is None:
            # Use logits as proxy
            hidden_states = base_output['logits']

        # Safety assessment
        safety_assessment = None
        if self.enable_safety and enforce_safety:
            safety_assessment = self.assess_safety(hidden_states)

            # Block unsafe content
            if not safety_assessment.is_safe:
                logger.warning(f"Unsafe content detected: {safety_assessment.flagged_categories}")

                # The replacement logits should live on the same device and
                # have the same batch size as what they replace; fall back to
                # the hidden states when the base model returned no logits.
                reference = base_output.get('logits')
                if not isinstance(reference, torch.Tensor):
                    reference = hidden_states

                # Create safe alternative response
                safe_response = self._generate_safe_response(
                    safety_assessment,
                    batch_size=reference.shape[0] if reference.dim() > 1 else 1,
                    device=reference.device,
                )
                base_output['logits'] = safe_response
                base_output['safety_blocked'] = True
                base_output['safety_assessment'] = safety_assessment

                return base_output

        # Self-critique and revision
        revision_info = {}
        if generate_critique and safety_assessment:
            if safety_assessment.overall_risk > self.revision_threshold:
                # Need revision
                original, revised, revision_info = self.critique_and_revise(
                    hidden_states,
                    hidden_states,  # Using same as context for simplicity
                    safety_assessment
                )

                # Update output with revision
                base_output['revised_hidden_states'] = revised
                base_output['revision_info'] = revision_info

        # Calculate constitutional loss if training
        if labels is not None and self.training:
            constitutional_loss = self._calculate_constitutional_loss(
                hidden_states,
                safety_assessment,
                revision_info
            )

            # Add to main loss
            if base_output.get('loss') is not None:
                base_output['loss'] = base_output['loss'] + self.constitutional_weight * constitutional_loss
            else:
                base_output['loss'] = constitutional_loss

            base_output['constitutional_loss'] = constitutional_loss

        # Add constitutional info to output
        base_output['constitutional_info'] = {
            'safety_assessment': safety_assessment.__dict__ if safety_assessment else None,
            'revision_info': revision_info,
            'principles_checked': len(self.value_verifier.principles),
        }

        return base_output

    def _generate_safe_response(
        self,
        safety_assessment: SafetyAssessment,
        batch_size: int = 1,
        device: Optional[torch.device] = None
    ) -> torch.Tensor:
        """Generate a placeholder safe alternative response.

        Args:
            safety_assessment: The assessment that triggered blocking
                (currently unused by the placeholder).
            batch_size: Number of rows in the returned tensor.
            device: Device to allocate on; ``None`` uses the default device.

        Returns:
            A ``(batch_size, 100, vocab_size)`` tensor of zeros with 0.3 at
            token ids 0, 1 and 2.
        """
        # Placeholder - would generate appropriate safe response
        seq_len = 100
        vocab_size = self.base_model.config.vocab_size

        # Create a generic safe response embedding
        safe_response = torch.zeros((batch_size, seq_len, vocab_size), device=device)

        # Set high probability for safe tokens (simplified)
        safe_tokens = [0, 1, 2]  # Would be actual safe token IDs
        for token in safe_tokens:
            safe_response[:, :, token] = 0.3

        return safe_response

    def _calculate_constitutional_loss(
        self,
        hidden_states: torch.Tensor,
        safety_assessment: Optional[SafetyAssessment],
        revision_info: Dict[str, Any]
    ) -> torch.Tensor:
        """Calculate loss for constitutional training.

        Sums up to three terms: the mean harm score, one minus the mean
        alignment score of the revision (when present), and one minus the
        mean alignment score of ``hidden_states``.

        NOTE(review): every term is built from Python floats (the harm and
        alignment scores are floats in this module), so the returned tensor
        carries no gradient with respect to model parameters.
        """
        total_loss = torch.tensor(0.0, device=hidden_states.device)

        # Safety loss
        if safety_assessment:
            # Penalize high harm scores
            harm_loss = sum(safety_assessment.harm_scores.values()) / len(safety_assessment.harm_scores)
            total_loss += harm_loss

        # Alignment loss
        if revision_info and 'alignment_scores' in revision_info:
            # Reward high alignment
            alignment_loss = 1.0 - (sum(revision_info['alignment_scores'].values()) /
                                   len(revision_info['alignment_scores']))
            total_loss += alignment_loss

        # Value verification loss
        alignment_scores = self.value_verifier(hidden_states)
        value_loss = 1.0 - (sum(alignment_scores.values()) / len(alignment_scores))
        total_loss += value_loss

        return total_loss

    def train_constitutional(
        self,
        dataloader,
        optimizer,
        num_epochs: int = 3,
        device: str = 'cuda'
    ):
        """Constitutional training loop.

        Args:
            dataloader: Iterable of dict-like batches with ``input_ids`` and
                optionally ``labels`` (``input_ids`` is used when absent).
            optimizer: Optimizer stepped once per batch.
            num_epochs: Number of passes over ``dataloader``.
            device: Device the batch tensors are moved to.

        Returns:
            The average loss of the last epoch that produced at least one
            batch, or ``0.0`` if no batch was ever processed.
        """
        self.train()

        # Defined up front so zero epochs / empty epochs still return a value
        avg_loss = 0.0

        for epoch in range(num_epochs):
            total_loss = 0
            num_batches = 0

            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                labels = batch.get('labels', input_ids).to(device)

                # Forward pass with constitutional reasoning
                outputs = self.forward(
                    input_ids,
                    labels=labels,
                    generate_critique=True,
                    enforce_safety=True
                )

                loss = outputs['loss']

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

                if num_batches % 100 == 0:
                    logger.info(f"Epoch {epoch}, Batch {num_batches}, "
                               f"Loss: {loss.item():.4f}, "
                               f"Constitutional Loss: {outputs.get('constitutional_loss', 0):.4f}")

            if num_batches == 0:
                # Avoid dividing by zero when the dataloader is empty
                logger.warning("Epoch %d: dataloader yielded no batches", epoch)
                continue

            avg_loss = total_loss / num_batches
            logger.info(f"Epoch {epoch} completed. Average loss: {avg_loss:.4f}")

        return avg_loss