Spaces:
Sleeping
Sleeping
File size: 5,910 Bytes
a2e1879 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
# llm_clients/finetuned_guard.py
from typing import Generator, Any, Dict, Optional
import json
from .base import LlmClient
class FinetunedGuardClient(LlmClient):
"""LLM client for finetuned model for safe/unsafe classification using zazaman/fmb."""
def __init__(self, config_dict: Dict[str, Any], system_prompt: str, shared_components: Optional[Dict[str, Any]] = None):
super().__init__(config_dict, system_prompt)
# If shared components are provided, use them instead of loading our own
if shared_components:
print(f" 🔗 FinetunedGuardClient: Using shared model components")
self.model = shared_components["model"]
self.tokenizer = shared_components["tokenizer"]
self.classifier = shared_components["classifier"]
self.transformers_available = True
return
# Fallback: Load our own model (this should rarely happen now)
print(f" ⚠️ FinetunedGuardClient: Loading independent model (shared components not available)")
try:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
# Disable torch compilation globally
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True
self.transformers_available = True
except ImportError:
raise ImportError(
"transformers library is required for FinetunedGuardClient. "
"Install it with: pip install transformers torch"
)
except AttributeError:
# If torch._dynamo doesn't exist in older versions, that's fine
self.transformers_available = True
# Get model name from config or use default
model_name = config_dict.get("model_name", "zazaman/fmb")
print(f"🔄 Loading finetuned model: {model_name}")
try:
# Disable torch compile optimizations for lightweight CPU-only devices
import os
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"
# Disable TensorFlow oneDNN warnings
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name,
torch_dtype=torch.float32, # Use float32 for CPU
device_map=None # Disable automatic device mapping
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Explicitly disable compilation on the model
if hasattr(self.model, '_compiler_config'):
self.model._compiler_config = None
# Use CPU device for lightweight operation
device = "cpu"
self.model = self.model.to(device)
self.classifier = pipeline(
"text-classification",
model=self.model,
tokenizer=self.tokenizer,
device=device,
framework="pt",
torch_dtype=torch.float32
)
print(f"✅ Finetuned Guard Client initialized successfully.")
print(f" Model: {model_name}")
print(f" Device: {device}")
except Exception as e:
raise RuntimeError(f"Failed to load finetuned model {model_name}: {e}")
def generate_content(self, prompt: str) -> str:
"""
Classifies the prompt as safe or unsafe using the finetuned model.
Returns a JSON response compatible with the existing AI detection system.
"""
try:
# Classify the prompt
result = self.classifier(prompt)[0]
# Extract the prediction and confidence
predicted_label = result['label']
confidence_score = result['score']
# Determine safety based on the model's prediction
# Assuming 'SAFE' and 'UNSAFE' are the labels from your fine-tuned model
is_safe = predicted_label.upper() == 'SAFE'
# Create response in the expected format
response_data = {
"safety_status": "safe" if is_safe else "unsafe",
"attack_type": "none" if is_safe else "prompt_injection",
"confidence": confidence_score,
"is_safe": is_safe,
"model_used": "zazaman/fmb",
"reason": f"Model predicted '{predicted_label}' with {confidence_score:.2%} confidence"
}
return json.dumps(response_data)
except Exception as e:
# Return error response in JSON format
error_response = {
"safety_status": "error",
"attack_type": "unknown",
"confidence": 0.0,
"is_safe": False,
"model_used": "zazaman/fmb",
"reason": f"Classification error: {str(e)}"
}
return json.dumps(error_response)
def generate_content_stream(self, prompt: str) -> Generator[str, None, None]:
"""
Streaming is not applicable for classification tasks.
Returns the classification result as a single chunk.
"""
yield self.generate_content(prompt)
def _generate_content_impl(self, prompt: str) -> str:
"""Implementation for base class compatibility."""
return self.generate_content(prompt)
def _generate_content_stream_impl(self, prompt: str) -> Generator[str, None, None]:
"""Implementation for base class compatibility."""
return self.generate_content_stream(prompt) |