Spaces:
Sleeping
Sleeping
| # guardrails/attachments/txt_guardrail.py | |
| import time | |
| import json | |
| from typing import Dict, Any, Tuple, List | |
| from .base import AttachmentGuardrail | |
| class TxtGuardrail(AttachmentGuardrail): | |
| """ | |
| Guardrail for text files (.txt, .md, etc.). | |
| Chunks text content and analyzes each chunk for prompt injection attacks. | |
| """ | |
| def __init__(self, config: Dict[str, Any]): | |
| super().__init__(config) | |
| self.chunk_size = config.get("chunk_size", 500) # tokens per chunk | |
| self.confidence_threshold = config.get("confidence_threshold", 0.75) | |
| self.max_file_size = config.get("max_file_size_mb", 10) * 1024 * 1024 # Convert MB to bytes | |
| # Initialize the finetuned model for analysis | |
| self.model_client = None | |
| self._init_model() | |
| def _init_model(self): | |
| """Initialize the finetuned model client for text analysis (using shared model)""" | |
| try: | |
| from llm_clients.shared_models import shared_model_manager | |
| self.model_client = shared_model_manager.get_finetuned_guard_client("zazaman/fmb") | |
| if self.model_client: | |
| print(f" 🔍 TXT Guardrail: Using shared model zazaman/fmb") | |
| else: | |
| print(f" ⚠️ TXT Guardrail: Could not get shared model") | |
| except Exception as e: | |
| print(f" ⚠️ TXT Guardrail: Could not initialize shared model: {e}") | |
| self.model_client = None | |
| def get_supported_extensions(self) -> List[str]: | |
| """Return supported text file extensions""" | |
| return ['.txt', '.md', '.text', '.rtf'] | |
| def process_file(self, file_path: str, file_content: bytes) -> Tuple[bool, Dict[str, Any]]: | |
| """ | |
| Process a text file by chunking and analyzing each chunk for threats. | |
| Args: | |
| file_path: Path/name of the uploaded file | |
| file_content: Raw bytes content of the file | |
| Returns: | |
| Tuple of (is_safe, analysis_details) | |
| """ | |
| start_time = time.time() | |
| # Get basic file info | |
| file_info = self.get_file_info(file_path, file_content) | |
| analysis_details = { | |
| **file_info, | |
| "chunk_size": self.chunk_size, | |
| "confidence_threshold": self.confidence_threshold, | |
| "chunks_analyzed": 0, | |
| "chunks_unsafe": 0, | |
| "max_confidence": 0.0, | |
| "analysis_time_ms": 0, | |
| "chunks_details": [], | |
| "model_used": "zazaman/fmb" | |
| } | |
| try: | |
| # Check file size | |
| if len(file_content) > self.max_file_size: | |
| analysis_details["error"] = f"File too large: {file_info['size_kb']}KB > {self.max_file_size/1024/1024}MB" | |
| return False, analysis_details | |
| # Check if model is available | |
| if not self.model_client: | |
| analysis_details["error"] = "Text analysis model not available" | |
| return False, analysis_details | |
| # Decode text content | |
| try: | |
| text_content = file_content.decode('utf-8') | |
| except UnicodeDecodeError: | |
| try: | |
| text_content = file_content.decode('latin-1') | |
| except UnicodeDecodeError: | |
| analysis_details["error"] = "Could not decode text file. Unsupported encoding." | |
| return False, analysis_details | |
| # Chunk the text | |
| chunks = self._chunk_text(text_content) | |
| analysis_details["chunks_analyzed"] = len(chunks) | |
| if not chunks: | |
| analysis_details["warning"] = "Empty file or no processable content" | |
| return True, analysis_details | |
| # Analyze each chunk | |
| unsafe_chunks = 0 | |
| max_confidence = 0.0 | |
| for i, chunk in enumerate(chunks): | |
| chunk_start_time = time.time() | |
| try: | |
| # Analyze chunk with the finetuned model | |
| response = self.model_client.generate_content(chunk) | |
| # Parse the JSON response | |
| ai_result = json.loads(response) | |
| confidence = ai_result.get("confidence", 0.0) | |
| safety_status = ai_result.get("safety_status", "unsafe") | |
| attack_type = ai_result.get("attack_type", "unknown") | |
| is_chunk_safe = safety_status.lower() == "safe" | |
| chunk_latency = round((time.time() - chunk_start_time) * 1000, 1) | |
| chunk_detail = { | |
| "chunk_index": i, | |
| "chunk_length": len(chunk), | |
| "is_safe": is_chunk_safe, | |
| "confidence": confidence, | |
| "safety_status": safety_status, | |
| "attack_type": attack_type, | |
| "latency_ms": chunk_latency, | |
| "preview": chunk[:100] + "..." if len(chunk) > 100 else chunk | |
| } | |
| analysis_details["chunks_details"].append(chunk_detail) | |
| # Track statistics | |
| max_confidence = max(max_confidence, confidence) | |
| # Check if chunk is unsafe with high confidence | |
| if not is_chunk_safe and confidence > self.confidence_threshold: | |
| unsafe_chunks += 1 | |
| chunk_detail["flagged"] = True | |
| print(f" 🚨 TXT Guardrail: Unsafe chunk {i+1}/{len(chunks)} detected (confidence: {confidence:.3f})") | |
| except Exception as e: | |
| # If we can't analyze a chunk, treat it as unsafe | |
| chunk_detail = { | |
| "chunk_index": i, | |
| "chunk_length": len(chunk), | |
| "is_safe": False, | |
| "error": str(e), | |
| "latency_ms": round((time.time() - chunk_start_time) * 1000, 1), | |
| "preview": chunk[:100] + "..." if len(chunk) > 100 else chunk | |
| } | |
| analysis_details["chunks_details"].append(chunk_detail) | |
| unsafe_chunks += 1 | |
| analysis_details["chunks_unsafe"] = unsafe_chunks | |
| analysis_details["max_confidence"] = max_confidence | |
| analysis_details["analysis_time_ms"] = round((time.time() - start_time) * 1000, 1) | |
| # File is safe if no chunks were flagged as unsafe | |
| is_file_safe = unsafe_chunks == 0 | |
| if not is_file_safe: | |
| analysis_details["threat_summary"] = f"Detected {unsafe_chunks} unsafe chunks out of {len(chunks)} total chunks" | |
| return is_file_safe, analysis_details | |
| except Exception as e: | |
| analysis_details["error"] = f"Unexpected error during analysis: {str(e)}" | |
| analysis_details["analysis_time_ms"] = round((time.time() - start_time) * 1000, 1) | |
| return False, analysis_details | |
| def _chunk_text(self, text: str) -> List[str]: | |
| """ | |
| Chunk text into pieces of approximately chunk_size tokens. | |
| Uses a simple word-based approximation (1 token ≈ 0.75 words). | |
| """ | |
| if not text.strip(): | |
| return [] | |
| # Approximate tokens using word count (1 token ≈ 0.75 words) | |
| # So for 500 tokens, we want ~667 words | |
| words_per_chunk = int(self.chunk_size / 0.75) | |
| # Split text into words | |
| words = text.split() | |
| if len(words) <= words_per_chunk: | |
| # Text is small enough to be a single chunk | |
| return [text] | |
| chunks = [] | |
| current_chunk_words = [] | |
| for word in words: | |
| current_chunk_words.append(word) | |
| # If we've reached the target chunk size, create a chunk | |
| if len(current_chunk_words) >= words_per_chunk: | |
| chunk_text = ' '.join(current_chunk_words) | |
| chunks.append(chunk_text) | |
| current_chunk_words = [] | |
| # Add remaining words as the last chunk | |
| if current_chunk_words: | |
| chunk_text = ' '.join(current_chunk_words) | |
| chunks.append(chunk_text) | |
| return chunks | |
| def _estimate_tokens(self, text: str) -> int: | |
| """Estimate token count using word count approximation""" | |
| words = len(text.split()) | |
| return int(words * 0.75) # Rough approximation: 1 token ≈ 0.75 words |