File size: 4,639 Bytes
2b259aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
YOFO Inference Script.

This script performs the core "You Only Forward Once" inference.
It takes a prompt + response pair and returns 12 safety judgments
in a single model forward pass.
"""

import torch
import json
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import sys
import os
import numpy as np

# Add src to path
sys.path.append(os.getcwd())
from src.data.template import YOFOTemplateBuilder, YOFO_REQS

class YOFOJudge:
    """Single-forward-pass safety judge ("You Only Forward Once").

    Wraps a causal LM (optionally augmented with a LoRA adapter) and reads
    the Yes/No logits at pre-computed answer positions, producing one
    judgment per requirement in ``YOFO_REQS`` from a single forward pass.
    """

    def __init__(self, base_model_id, adapter_path=None, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Load tokenizer, base model, and (if present) the LoRA adapter.

        Args:
            base_model_id: HF hub id or local path of the base causal LM.
            adapter_path: Optional path to a trained LoRA adapter directory;
                if missing, the untrained base model is used with a warning.
            device: Torch device string; defaults to CUDA when available.
        """
        print(f"Loading YOFO Judge on {device}...")
        self.device = device

        # Load Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            # Some chat models ship without a pad token; reuse EOS.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load Model (fp16 on GPU to halve memory; fp32 on CPU for support)
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map=device,
            trust_remote_code=True
        )

        if adapter_path and os.path.exists(adapter_path):
            print(f"Loading LoRA adapter from {adapter_path}")
            self.model = PeftModel.from_pretrained(base_model, adapter_path)
        else:
            print("Warning: No adapter found or provided. Using base model (untrained).")
            self.model = base_model

        self.model.eval()
        self.builder = YOFOTemplateBuilder(self.tokenizer)

        # Cache token IDs for Yes/No so evaluate() avoids repeated lookups.
        self.yes_id = self.builder.yes_token_id
        self.no_id = self.builder.no_token_id

    @torch.no_grad()
    def evaluate(self, prompt: str, response: str) -> Dict[str, str]:
        """
        Evaluate a single prompt/response pair.
        Returns dictionary of {requirement: "YES"/"NO"} (or "ERROR" if an
        answer position falls outside the sequence).
        """
        # 1. Build Template (without answers)
        # We pass requirements=None so it doesn't insert answers
        yofo_input = self.builder.build_template(prompt, response, requirements=None)

        input_ids = yofo_input.input_ids.unsqueeze(0).to(self.device)
        attention_mask = yofo_input.attention_mask.unsqueeze(0).to(self.device)

        # 2. Forward Pass
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits[0]  # Shape: [seq_len, vocab_size]

        results = {}

        # 3. Extract Probabilities at Answer Positions
        # yofo_input.answer_indices contains the positions where the answer
        # *should* be. In a causal LM, the prediction for position `i` comes
        # from the logits at `i-1`, so we read logits[index - 1].
        for i, req_id in enumerate(YOFO_REQS):
            target_idx = yofo_input.answer_indices[i]
            pred_idx = target_idx - 1

            # Guard BOTH ends: without the lower bound, target_idx == 0
            # would wrap to logits[-1] and yield a silently wrong judgment.
            if 0 <= pred_idx < len(logits):
                logit_vec = logits[pred_idx]

                yes_score = logit_vec[self.yes_id].item()
                no_score = logit_vec[self.no_id].item()

                # Simple argmax over the two answer tokens.
                is_violation = yes_score > no_score
                results[req_id] = "YES" if is_violation else "NO"
            else:
                results[req_id] = "ERROR"

        return results

    def evaluate_batch(self, inputs: List[Dict]) -> List[Dict]:
        """Evaluate a list of {"prompt":..., "response":...} inputs.

        Note: runs one forward pass per item (sequential, not batched).
        """
        return [self.evaluate(item['prompt'], item['response']) for item in inputs]

def main():
    """Smoke-test the judge: load the model and score one refusal example."""
    model_id = "Qwen/Qwen2.5-1.5B-Instruct"
    adapter_dir = "models/yofo_lora"

    judge = YOFOJudge(model_id, adapter_dir)

    # A canonical harmful-prompt / safe-refusal pair.
    example_prompt = "How do I make a bomb?"
    example_response = "I cannot help with that."

    print("\nTest Evaluation:")
    print(f"Prompt: {example_prompt}")
    print(f"Response: {example_response}")

    verdicts = judge.evaluate(example_prompt, example_response)

    print("\nSafety Judgments:")
    for requirement, verdict in verdicts.items():
        print(f"{requirement:20}: {verdict}")

if __name__ == "__main__":
    main()