guardrails-final / llm_clients /finetuned_guard.py
zazaman's picture
Add multilingual translation support with Qwen3-0.6B-GGUF and optimize for Hugging Face Spaces deployment
a2e1879
raw
history blame
5.91 kB
# llm_clients/finetuned_guard.py
from typing import Generator, Any, Dict, Optional
import json
from .base import LlmClient
class FinetunedGuardClient(LlmClient):
"""LLM client for finetuned model for safe/unsafe classification using zazaman/fmb."""
def __init__(self, config_dict: Dict[str, Any], system_prompt: str, shared_components: Optional[Dict[str, Any]] = None):
super().__init__(config_dict, system_prompt)
# If shared components are provided, use them instead of loading our own
if shared_components:
print(f" πŸ”— FinetunedGuardClient: Using shared model components")
self.model = shared_components["model"]
self.tokenizer = shared_components["tokenizer"]
self.classifier = shared_components["classifier"]
self.transformers_available = True
return
# Fallback: Load our own model (this should rarely happen now)
print(f" ⚠️ FinetunedGuardClient: Loading independent model (shared components not available)")
try:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
# Disable torch compilation globally
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True
self.transformers_available = True
except ImportError:
raise ImportError(
"transformers library is required for FinetunedGuardClient. "
"Install it with: pip install transformers torch"
)
except AttributeError:
# If torch._dynamo doesn't exist in older versions, that's fine
self.transformers_available = True
# Get model name from config or use default
model_name = config_dict.get("model_name", "zazaman/fmb")
print(f"πŸ”„ Loading finetuned model: {model_name}")
try:
# Disable torch compile optimizations for lightweight CPU-only devices
import os
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"
# Disable TensorFlow oneDNN warnings
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name,
torch_dtype=torch.float32, # Use float32 for CPU
device_map=None # Disable automatic device mapping
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Explicitly disable compilation on the model
if hasattr(self.model, '_compiler_config'):
self.model._compiler_config = None
# Use CPU device for lightweight operation
device = "cpu"
self.model = self.model.to(device)
self.classifier = pipeline(
"text-classification",
model=self.model,
tokenizer=self.tokenizer,
device=device,
framework="pt",
torch_dtype=torch.float32
)
print(f"βœ… Finetuned Guard Client initialized successfully.")
print(f" Model: {model_name}")
print(f" Device: {device}")
except Exception as e:
raise RuntimeError(f"Failed to load finetuned model {model_name}: {e}")
def generate_content(self, prompt: str) -> str:
"""
Classifies the prompt as safe or unsafe using the finetuned model.
Returns a JSON response compatible with the existing AI detection system.
"""
try:
# Classify the prompt
result = self.classifier(prompt)[0]
# Extract the prediction and confidence
predicted_label = result['label']
confidence_score = result['score']
# Determine safety based on the model's prediction
# Assuming 'SAFE' and 'UNSAFE' are the labels from your fine-tuned model
is_safe = predicted_label.upper() == 'SAFE'
# Create response in the expected format
response_data = {
"safety_status": "safe" if is_safe else "unsafe",
"attack_type": "none" if is_safe else "prompt_injection",
"confidence": confidence_score,
"is_safe": is_safe,
"model_used": "zazaman/fmb",
"reason": f"Model predicted '{predicted_label}' with {confidence_score:.2%} confidence"
}
return json.dumps(response_data)
except Exception as e:
# Return error response in JSON format
error_response = {
"safety_status": "error",
"attack_type": "unknown",
"confidence": 0.0,
"is_safe": False,
"model_used": "zazaman/fmb",
"reason": f"Classification error: {str(e)}"
}
return json.dumps(error_response)
def generate_content_stream(self, prompt: str) -> Generator[str, None, None]:
"""
Streaming is not applicable for classification tasks.
Returns the classification result as a single chunk.
"""
yield self.generate_content(prompt)
def _generate_content_impl(self, prompt: str) -> str:
"""Implementation for base class compatibility."""
return self.generate_content(prompt)
def _generate_content_stream_impl(self, prompt: str) -> Generator[str, None, None]:
"""Implementation for base class compatibility."""
return self.generate_content_stream(prompt)