Spaces:
Sleeping
Sleeping
| # llm_clients/finetuned_guard.py | |
| from typing import Generator, Any, Dict, Optional | |
| import json | |
| from .base import LlmClient | |
| class FinetunedGuardClient(LlmClient): | |
| """LLM client for finetuned model for safe/unsafe classification using zazaman/fmb.""" | |
| def __init__(self, config_dict: Dict[str, Any], system_prompt: str, shared_components: Optional[Dict[str, Any]] = None): | |
| super().__init__(config_dict, system_prompt) | |
| # If shared components are provided, use them instead of loading our own | |
| if shared_components: | |
| print(f" π FinetunedGuardClient: Using shared model components") | |
| self.model = shared_components["model"] | |
| self.tokenizer = shared_components["tokenizer"] | |
| self.classifier = shared_components["classifier"] | |
| self.transformers_available = True | |
| return | |
| # Fallback: Load our own model (this should rarely happen now) | |
| print(f" β οΈ FinetunedGuardClient: Loading independent model (shared components not available)") | |
| try: | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline | |
| import torch | |
| # Disable torch compilation globally | |
| torch._dynamo.config.suppress_errors = True | |
| torch._dynamo.config.disable = True | |
| self.transformers_available = True | |
| except ImportError: | |
| raise ImportError( | |
| "transformers library is required for FinetunedGuardClient. " | |
| "Install it with: pip install transformers torch" | |
| ) | |
| except AttributeError: | |
| # If torch._dynamo doesn't exist in older versions, that's fine | |
| self.transformers_available = True | |
| # Get model name from config or use default | |
| model_name = config_dict.get("model_name", "zazaman/fmb") | |
| print(f"π Loading finetuned model: {model_name}") | |
| try: | |
| # Disable torch compile optimizations for lightweight CPU-only devices | |
| import os | |
| os.environ["TORCH_COMPILE_DISABLE"] = "1" | |
| os.environ["TORCHDYNAMO_DISABLE"] = "1" | |
| # Disable TensorFlow oneDNN warnings | |
| os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" | |
| self.model = AutoModelForSequenceClassification.from_pretrained( | |
| model_name, | |
| torch_dtype=torch.float32, # Use float32 for CPU | |
| device_map=None # Disable automatic device mapping | |
| ) | |
| self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # Explicitly disable compilation on the model | |
| if hasattr(self.model, '_compiler_config'): | |
| self.model._compiler_config = None | |
| # Use CPU device for lightweight operation | |
| device = "cpu" | |
| self.model = self.model.to(device) | |
| self.classifier = pipeline( | |
| "text-classification", | |
| model=self.model, | |
| tokenizer=self.tokenizer, | |
| device=device, | |
| framework="pt", | |
| torch_dtype=torch.float32 | |
| ) | |
| print(f"β Finetuned Guard Client initialized successfully.") | |
| print(f" Model: {model_name}") | |
| print(f" Device: {device}") | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load finetuned model {model_name}: {e}") | |
| def generate_content(self, prompt: str) -> str: | |
| """ | |
| Classifies the prompt as safe or unsafe using the finetuned model. | |
| Returns a JSON response compatible with the existing AI detection system. | |
| """ | |
| try: | |
| # Classify the prompt | |
| result = self.classifier(prompt)[0] | |
| # Extract the prediction and confidence | |
| predicted_label = result['label'] | |
| confidence_score = result['score'] | |
| # Determine safety based on the model's prediction | |
| # Assuming 'SAFE' and 'UNSAFE' are the labels from your fine-tuned model | |
| is_safe = predicted_label.upper() == 'SAFE' | |
| # Create response in the expected format | |
| response_data = { | |
| "safety_status": "safe" if is_safe else "unsafe", | |
| "attack_type": "none" if is_safe else "prompt_injection", | |
| "confidence": confidence_score, | |
| "is_safe": is_safe, | |
| "model_used": "zazaman/fmb", | |
| "reason": f"Model predicted '{predicted_label}' with {confidence_score:.2%} confidence" | |
| } | |
| return json.dumps(response_data) | |
| except Exception as e: | |
| # Return error response in JSON format | |
| error_response = { | |
| "safety_status": "error", | |
| "attack_type": "unknown", | |
| "confidence": 0.0, | |
| "is_safe": False, | |
| "model_used": "zazaman/fmb", | |
| "reason": f"Classification error: {str(e)}" | |
| } | |
| return json.dumps(error_response) | |
| def generate_content_stream(self, prompt: str) -> Generator[str, None, None]: | |
| """ | |
| Streaming is not applicable for classification tasks. | |
| Returns the classification result as a single chunk. | |
| """ | |
| yield self.generate_content(prompt) | |
| def _generate_content_impl(self, prompt: str) -> str: | |
| """Implementation for base class compatibility.""" | |
| return self.generate_content(prompt) | |
| def _generate_content_stream_impl(self, prompt: str) -> Generator[str, None, None]: | |
| """Implementation for base class compatibility.""" | |
| return self.generate_content_stream(prompt) |