File size: 5,910 Bytes
a2e1879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# llm_clients/finetuned_guard.py
from typing import Generator, Any, Dict, Optional
import json
from .base import LlmClient

class FinetunedGuardClient(LlmClient):
    """LLM client for finetuned model for safe/unsafe classification using zazaman/fmb."""

    def __init__(self, config_dict: Dict[str, Any], system_prompt: str, shared_components: Optional[Dict[str, Any]] = None):
        super().__init__(config_dict, system_prompt)
        
        # If shared components are provided, use them instead of loading our own
        if shared_components:
            print(f"   🔗 FinetunedGuardClient: Using shared model components")
            self.model = shared_components["model"]
            self.tokenizer = shared_components["tokenizer"]
            self.classifier = shared_components["classifier"]
            self.transformers_available = True
            return
        
        # Fallback: Load our own model (this should rarely happen now)
        print(f"   ⚠️  FinetunedGuardClient: Loading independent model (shared components not available)")
        
        try:
            from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
            import torch
            
            # Disable torch compilation globally
            torch._dynamo.config.suppress_errors = True
            torch._dynamo.config.disable = True
            
            self.transformers_available = True
        except ImportError:
            raise ImportError(
                "transformers library is required for FinetunedGuardClient. "
                "Install it with: pip install transformers torch"
            )
        except AttributeError:
            # If torch._dynamo doesn't exist in older versions, that's fine
            self.transformers_available = True
        
        # Get model name from config or use default
        model_name = config_dict.get("model_name", "zazaman/fmb")
        print(f"🔄 Loading finetuned model: {model_name}")
        
        try:
            # Disable torch compile optimizations for lightweight CPU-only devices
            import os
            os.environ["TORCH_COMPILE_DISABLE"] = "1"
            os.environ["TORCHDYNAMO_DISABLE"] = "1"
            # Disable TensorFlow oneDNN warnings
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_name,
                torch_dtype=torch.float32,  # Use float32 for CPU
                device_map=None  # Disable automatic device mapping
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            
            # Explicitly disable compilation on the model
            if hasattr(self.model, '_compiler_config'):
                self.model._compiler_config = None
            
            # Use CPU device for lightweight operation
            device = "cpu"
            self.model = self.model.to(device)
            
            self.classifier = pipeline(
                "text-classification", 
                model=self.model, 
                tokenizer=self.tokenizer,
                device=device,
                framework="pt",
                torch_dtype=torch.float32
            )
            
            print(f"✅ Finetuned Guard Client initialized successfully.")
            print(f"   Model: {model_name}")
            print(f"   Device: {device}")
            
        except Exception as e:
            raise RuntimeError(f"Failed to load finetuned model {model_name}: {e}")

    def generate_content(self, prompt: str) -> str:
        """
        Classifies the prompt as safe or unsafe using the finetuned model.
        Returns a JSON response compatible with the existing AI detection system.
        """
        try:
            # Classify the prompt
            result = self.classifier(prompt)[0]
            
            # Extract the prediction and confidence
            predicted_label = result['label']
            confidence_score = result['score']
            
            # Determine safety based on the model's prediction
            # Assuming 'SAFE' and 'UNSAFE' are the labels from your fine-tuned model
            is_safe = predicted_label.upper() == 'SAFE'
            
            # Create response in the expected format
            response_data = {
                "safety_status": "safe" if is_safe else "unsafe",
                "attack_type": "none" if is_safe else "prompt_injection",
                "confidence": confidence_score,
                "is_safe": is_safe,
                "model_used": "zazaman/fmb",
                "reason": f"Model predicted '{predicted_label}' with {confidence_score:.2%} confidence"
            }
            
            return json.dumps(response_data)
            
        except Exception as e:
            # Return error response in JSON format
            error_response = {
                "safety_status": "error",
                "attack_type": "unknown",
                "confidence": 0.0,
                "is_safe": False,
                "model_used": "zazaman/fmb",
                "reason": f"Classification error: {str(e)}"
            }
            return json.dumps(error_response)

    def generate_content_stream(self, prompt: str) -> Generator[str, None, None]:
        """
        Streaming is not applicable for classification tasks.
        Returns the classification result as a single chunk.
        """
        yield self.generate_content(prompt)
    
    def _generate_content_impl(self, prompt: str) -> str:
        """Implementation for base class compatibility."""
        return self.generate_content(prompt)

    def _generate_content_stream_impl(self, prompt: str) -> Generator[str, None, None]:
        """Implementation for base class compatibility."""
        return self.generate_content_stream(prompt)