# src/utils.py
import os
import logging
import tempfile
import hashlib
import secrets
import time
from typing import Tuple, Optional, Dict
import streamlit as st
from config import config
from pydub import AudioSegment
import threading
import librosa
import soundfile as sf

# =========================
# Logging
# =========================
def setup_production_logging():
    """Set up production-grade logging for HF Spaces.

    When ``config.ENABLE_LOGGING`` is true and the root logger has no handlers
    yet, installs an INFO-level file handler writing to
    ``<tempdir>/drug_detector.log`` and a WARNING-level console handler.

    Returns:
        The logger for this module (``logging.getLogger(__name__)``).
    """
    if not config.ENABLE_LOGGING:
        return logging.getLogger(__name__)

    # logging.basicConfig() does nothing once the root logger has handlers, so
    # only create (and open) the handlers when they will actually be installed;
    # otherwise the FileHandler's open file would simply leak.
    if not logging.getLogger().handlers:
        log_format = "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s"

        file_handler = logging.FileHandler(os.path.join(tempfile.gettempdir(), "drug_detector.log"))
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(log_format))

        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.WARNING)
        console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

        logging.basicConfig(level=logging.INFO, handlers=[file_handler, console_handler])

    logger = logging.getLogger(__name__)
    logger.info("Production logging initialized (HF Spaces)")
    return logger


logger = setup_production_logging()

# =========================
# Rate-limiting (in-memory)
# =========================
# Guards SecurityManager.request_counts across Streamlit script threads.
rate_limit_lock = threading.Lock()


class SecurityManager:
    """Per-client sliding-window rate limiting held in process memory (HF Spaces friendly)."""

    def __init__(self):
        # client_id -> time.time() stamps of accepted requests, oldest first.
        self.request_counts: Dict[str, list] = {}  # in-memory only

    @staticmethod
    def _new_client_id() -> str:
        """Return a fresh random 16-hex-character client identifier."""
        # A random token, unlike md5(time.time()), cannot collide for two
        # sessions started within the same clock tick.
        return secrets.token_hex(8)

    def get_client_id(self) -> str:
        """Return the client identifier used for rate limiting.

        The id is cached in ``st.session_state`` so it is stable for one
        Streamlit session. If session state is unavailable, a new id is
        returned on every call, so such callers are effectively not limited.
        """
        try:
            if "client_id" not in st.session_state:
                st.session_state.client_id = self._new_client_id()
            return st.session_state.client_id
        except Exception:
            # Fallback for contexts where session_state is unavailable
            return self._new_client_id()

    def _prune_stale_clients(self, current_time: float) -> None:
        """Forget clients with no request inside the window (caller holds rate_limit_lock).

        Without this, every session id ever seen stays in ``request_counts``
        forever and memory grows without bound.
        """
        window = config.RATE_LIMIT_WINDOW
        stale = [
            cid for cid, stamps in self.request_counts.items()
            # stamps are appended in order, so the last one is the newest
            if not stamps or current_time - stamps[-1] >= window
        ]
        for cid in stale:
            del self.request_counts[cid]

    def check_rate_limit(self) -> Tuple[bool, Optional[str]]:
        """Check whether the current client is within its rate limit.

        Returns:
            ``(True, None)`` if the request is allowed (and records it), or
            ``(False, message)`` if ``config.RATE_LIMIT_REQUESTS`` requests
            were already made within ``config.RATE_LIMIT_WINDOW`` seconds.
            Any internal error fails open and returns ``(True, None)``.
        """
        try:
            client_id = self.get_client_id()
            current_time = time.time()

            with rate_limit_lock:
                self._prune_stale_clients(current_time)

                # Keep only requests still inside the sliding window
                recent = [
                    t for t in self.request_counts.get(client_id, [])
                    if current_time - t < config.RATE_LIMIT_WINDOW
                ]

                if len(recent) >= config.RATE_LIMIT_REQUESTS:
                    self.request_counts[client_id] = recent
                    logger.warning(f"Rate limit exceeded for client: {client_id}")
                    return False, f"Rate limit exceeded. Max {config.RATE_LIMIT_REQUESTS} requests per hour."

                # Record the current request
                recent.append(current_time)
                self.request_counts[client_id] = recent
                return True, None

        except Exception as e:
            logger.error(f"Rate limit check failed: {e}")
            return True, None  # allow request on error

# =========================
# File Handling
# =========================
class FileManager:
    """Secure file handling for HF Spaces."""

    @staticmethod
    def create_secure_temp_file(uploaded_file) -> Optional[str]:
        """Copy an uploaded file into ``<tempdir>/drug_audio`` under a random name.

        Args:
            uploaded_file: object exposing ``name``, ``seek`` and ``read``
                (e.g. a Streamlit upload).

        Returns:
            The path of the new file, or ``None`` if it could not be written.
        """
        temp_path = None
        try:
            temp_dir = os.path.join(tempfile.gettempdir(), "drug_audio")
            os.makedirs(temp_dir, exist_ok=True)

            # Random token so same-named uploads in the same tick never collide.
            file_hash = secrets.token_hex(6)
            raw_ext = uploaded_file.name.split('.')[-1].lower()
            # The upload name is untrusted: keep only alphanumerics so the
            # extension cannot smuggle path separators into the filename.
            file_ext = "".join(ch for ch in raw_ext if ch.isalnum())
            secure_filename = f"audio_{file_hash}.{file_ext}"
            temp_path = os.path.join(temp_dir, secure_filename)

            uploaded_file.seek(0)
            with open(temp_path, "wb") as f:
                f.write(uploaded_file.read())

            logger.info(f"Created secure temp file: {secure_filename}")
            return temp_path

        except Exception as e:
            logger.error(f"Failed to create temp file: {e}")
            # Don't leave a partially written file behind in the temp dir.
            if temp_path and os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                except OSError as cleanup_err:
                    logger.warning(f"Failed to delete file {temp_path}: {cleanup_err}")
            return None

    @staticmethod
    def cleanup_file(file_path: str, is_temp: bool = True):
        """Securely delete a temp file. Sample files will never be deleted.

        Nothing is removed when ``file_path`` is empty, ``is_temp`` is false,
        or the path is a bundled sample (see ``is_sample_file``). Deletion
        failures are logged, not raised.
        """
        if not file_path or not is_temp or FileManager.is_sample_file(file_path):
            return
        try:
            if os.path.exists(file_path):
                os.unlink(file_path)
                logger.info(f"Deleted temp file: {os.path.basename(file_path)}")
        except Exception as e:
            logger.warning(f"Failed to delete file {file_path}: {e}")

    @staticmethod
    def is_sample_file(file_path: str) -> bool:
        """Return True if the path looks like a bundled sample (contains ``audio_sample``)."""
        if not file_path:
            return False
        return "audio_sample" in file_path

# =========================
# File Validation
# =========================
class AudioValidator:
    """Audio file validation with better HF Spaces compatibility."""

    @staticmethod
    def validate_file(uploaded_file) -> Tuple[bool, str]:
        """Validate an uploaded audio file.

        Checks, in order: size against ``config.MAX_FILE_SIZE_MB``; extension
        against ``config.ALLOWED_EXTENSIONS``; filename for ``..``, ``/`` and
        ``\\``; decodability via librosa; duration against
        ``config.MAX_AUDIO_DURATION``; and non-silence (peak amplitude at least 1e-6).

        Returns:
            ``(is_valid, message)``; any unexpected error yields ``False``.
        """
        temp_path = None
        try:
            # Size check
            size_mb = uploaded_file.size / (1024 * 1024)
            if size_mb > config.MAX_FILE_SIZE_MB:
                return False, f"File too large: {size_mb:.1f}MB (max {config.MAX_FILE_SIZE_MB}MB)"

            # Extension check
            ext = uploaded_file.name.split('.')[-1].lower()
            if ext not in config.ALLOWED_EXTENSIONS:
                return False, f"Unsupported file type: {ext}"

            # Filename check
            if any(c in uploaded_file.name for c in ['..', '/', '\\']):
                return False, "Invalid filename"

            # Write to a temporary file so librosa can decode it. Bind temp_path
            # before writing so the finally clause can remove it even if the
            # write fails.
            uploaded_file.seek(0)
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}") as temp_file:
                temp_path = temp_file.name
                temp_file.write(uploaded_file.read())
            # Rewind so later readers of the upload start from the beginning.
            uploaded_file.seek(0)

            # Use librosa for better audio format support
            y, sr = librosa.load(temp_path, sr=None)
            duration_sec = len(y) / sr

            if duration_sec > config.MAX_AUDIO_DURATION:
                return False, f"Audio too long: {duration_sec:.1f}s (max {config.MAX_AUDIO_DURATION}s)"

            # Check if audio has content. Vectorised peak: builtin max(abs(y))
            # iterated the whole sample array in Python.
            if len(y) == 0 or float(abs(y).max()) < 1e-6:
                return False, "Audio file appears to be empty or silent"

            return True, f"Valid audio: {duration_sec:.1f}s, {sr}Hz"

        except Exception as e:
            logger.error(f"File validation error: {e}")
            return False, f"Validation failed: {str(e)}"

        finally:
            # Clean up temp file (best effort, but don't hide the failure).
            if temp_path:
                try:
                    os.unlink(temp_path)
                except OSError as cleanup_err:
                    logger.warning(f"Failed to delete file {temp_path}: {cleanup_err}")


def is_valid_audio(file_path) -> bool:
    """Return True if librosa can decode the first second of ``file_path`` into samples."""
    try:
        y, sr = librosa.load(file_path, sr=None, duration=1.0)  # Test first second
        return len(y) > 0 and sr > 0
    except Exception as e:
        logger.error(f"Audio validation failed for {file_path}: {e}")
        return False

# =========================
# Model Validation
# =========================
class ModelManager:
    """Model validation for HF Spaces."""

    @staticmethod
    def validate_model_availability() -> Tuple[bool, str]:
        """Check that all required files exist in ``config.MODEL_PATH``.

        Returns:
            ``(True, "Model ready")`` or ``(False, reason)``.
        """
        try:
            model_path = config.MODEL_PATH
            if not os.path.exists(model_path):
                return False, f"Model directory not found: {model_path}"

            required_files = [
                "config.json",
                "model_config.json",
                "model_weights.json",
                "special_tokens_map.json",
                "tokenizer.json",
                "tokenizer_config.json",
                "vocab.txt",
            ]
            missing = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
            if missing:
                return False, f"Missing model files: {', '.join(missing)}"

            return True, "Model ready"
        except Exception as e:
            # Log like the other managers do instead of failing silently.
            logger.error(f"Model validation failed: {e}")
            return False, f"Model validation failed: {str(e)}"

    @staticmethod
    def load_model(model_path: str):
        """Load a sequence-classification model and its tokenizer via ``from_pretrained``.

        Returns:
            ``(model, tokenizer)``.

        Raises:
            Re-raises any loading error after logging it.
        """
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        from pathlib import Path

        try:
            model_path = Path(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            logger.info(f"Loaded model & tokenizer from {model_path}")
            return model, tokenizer
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

# =========================
# Global instances
# =========================
security_manager = SecurityManager()
file_manager = FileManager()
model_manager = ModelManager()