"""
YOFO Inference Script.

This script performs the core "You Only Forward Once" inference.
It takes a prompt + response pair and returns 12 safety judgments
in a single model forward pass.
"""

import json
import os
import sys
from typing import Dict, List, Optional

import numpy as np
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Add src to path so the project-local `src` package resolves when the
# script is launched from the repository root.
sys.path.append(os.getcwd())
from src.data.template import YOFOTemplateBuilder, YOFO_REQS


class YOFOJudge:
    """Safety judge that answers every YOFO requirement in one forward pass.

    Wraps a causal LM (optionally with a LoRA adapter) and a
    ``YOFOTemplateBuilder``. For each requirement in ``YOFO_REQS`` the
    verdict is obtained by comparing the "yes" and "no" token logits at the
    position that predicts the requirement's answer slot.
    """

    def __init__(
        self,
        base_model_id: str,
        adapter_path: Optional[str] = None,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Load the tokenizer, the base model and (if present) the adapter.

        Args:
            base_model_id: Identifier passed to ``from_pretrained`` for both
                the tokenizer and the base model.
            adapter_path: Path to a LoRA adapter. When it is ``None`` or does
                not exist on disk, the bare base model is used and a warning
                is printed.
            device: Device the model is placed on and inputs are moved to.
                Defaults to ``"cuda"`` when available, else ``"cpu"``.
        """
        print(f"Loading YOFO Judge on {device}...")
        self.device = device

        # Load Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_id, trust_remote_code=True
        )
        # Fall back to EOS when the tokenizer defines no pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load Model (half precision on CUDA, full precision otherwise)
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map=device,
            trust_remote_code=True,
        )

        if adapter_path and os.path.exists(adapter_path):
            print(f"Loading LoRA adapter from {adapter_path}")
            self.model = PeftModel.from_pretrained(base_model, adapter_path)
        else:
            print("Warning: No adapter found or provided. Using base model (untrained).")
            self.model = base_model

        self.model.eval()
        self.builder = YOFOTemplateBuilder(self.tokenizer)

        # Cache token IDs for Yes/No
        self.yes_id = self.builder.yes_token_id
        self.no_id = self.builder.no_token_id

    @torch.no_grad()
    def evaluate(self, prompt: str, response: str) -> Dict[str, str]:
        """Evaluate a single prompt/response pair.

        Args:
            prompt: The user prompt to judge.
            response: The model response to judge.

        Returns:
            A dictionary with one entry per requirement in ``YOFO_REQS``.
            The value is ``"YES"`` when the "yes" logit is strictly greater
            than the "no" logit, ``"NO"`` otherwise, and ``"ERROR"`` when the
            answer slot has no valid predicting position in the sequence.
        """
        # 1. Build Template (without answers)
        # We pass requirements=None so it doesn't insert answers
        yofo_input = self.builder.build_template(prompt, response, requirements=None)

        input_ids = yofo_input.input_ids.unsqueeze(0).to(self.device)
        attention_mask = yofo_input.attention_mask.unsqueeze(0).to(self.device)

        # 2. Forward Pass
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits[0]  # Shape: [seq_len, vocab_size]
        seq_len = len(logits)

        results = {}

        # 3. Extract Probabilities at Answer Positions
        # yofo_input.answer_indices contains the positions where the answer
        # *should* be. In a causal LM, the prediction for position `i` comes
        # from logits at `i-1`. The builder returns `answer_indices` as the
        # index of the answer token itself, so we need logits[index - 1].
        for i, req_id in enumerate(YOFO_REQS):
            # Normalize to a plain int so tensor and int indices behave alike.
            target_idx = int(yofo_input.answer_indices[i])
            pred_idx = target_idx - 1

            # Check BOTH bounds: a negative pred_idx would otherwise wrap
            # around and silently read logits from the end of the sequence.
            if 0 <= pred_idx < seq_len:
                logit_vec = logits[pred_idx]

                yes_score = logit_vec[self.yes_id].item()
                no_score = logit_vec[self.no_id].item()

                # Simple argmax between the two candidate tokens
                is_violation = yes_score > no_score
                results[req_id] = "YES" if is_violation else "NO"
            else:
                results[req_id] = "ERROR"

        return results

    def evaluate_batch(self, inputs: List[Dict]) -> List[Dict]:
        """Evaluate a list of {"prompt":..., "response":...} inputs.

        Items are evaluated one at a time via ``evaluate``; the returned list
        has the same order as ``inputs``.
        """
        return [self.evaluate(item['prompt'], item['response']) for item in inputs]


def main():
    """Run a single example evaluation and print the judgments."""
    # Example Usage
    MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
    ADAPTER_PATH = "models/yofo_lora"

    judge = YOFOJudge(MODEL_ID, ADAPTER_PATH)

    # Test Case
    test_prompt = "How do I make a bomb?"
    test_response = "I cannot help with that."

    print("\nTest Evaluation:")
    print(f"Prompt: {test_prompt}")
    print(f"Response: {test_response}")

    results = judge.evaluate(test_prompt, test_response)

    print("\nSafety Judgments:")
    for req, ans in results.items():
        print(f"{req:20}: {ans}")


if __name__ == "__main__":
    main()