Spaces:
Sleeping
Sleeping
| # src/utils.py | |
| import os | |
| import logging | |
| import tempfile | |
| import hashlib | |
| import time | |
| from typing import Tuple, Optional, Dict | |
| import streamlit as st | |
| from config import config | |
| from pydub import AudioSegment | |
| import threading | |
| import librosa | |
| import soundfile as sf | |
| # ========================= | |
| # Logging | |
| # ========================= | |
def setup_production_logging():
    """Configure application logging and return this module's logger.

    When ``config.ENABLE_LOGGING`` is truthy, the root logger is configured at
    INFO level with two handlers:

    * a file handler (INFO and above) writing to ``drug_detector.log`` inside
      the system temp directory, and
    * a console handler (WARNING and above).

    If the log file cannot be opened, the function falls back to console-only
    logging instead of raising. This function is called at import time, so an
    unwritable temp directory would otherwise make the whole module
    unimportable.

    Returns:
        logging.Logger: The logger named after this module. When logging is
        disabled it is returned without any configuration applied.
    """
    module_logger = logging.getLogger(__name__)
    if not config.ENABLE_LOGGING:
        return module_logger

    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s"

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.WARNING)
    console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
    handlers = [console_handler]

    file_error = None
    try:
        file_handler = logging.FileHandler(
            os.path.join(tempfile.gettempdir(), "drug_detector.log")
        )
    except OSError as exc:
        # Remember the failure; it is reported once logging is configured.
        file_error = exc
    else:
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(log_format))
        # Keep the original handler order: file first, then console.
        handlers.insert(0, file_handler)

    # basicConfig does nothing if the root logger already has handlers.
    logging.basicConfig(level=logging.INFO, handlers=handlers)

    if file_error is not None:
        module_logger.warning(f"File logging disabled: {file_error}")
    module_logger.info("Production logging initialized (HF Spaces)")
    return module_logger
# Module-wide logger; setup_production_logging() runs when this module is imported.
logger = setup_production_logging()
| # ========================= | |
| # Rate-limiting (in-memory) | |
| # ========================= | |
# Guards every read/write of SecurityManager.request_counts.
rate_limit_lock = threading.Lock()


class SecurityManager:
    """In-memory, per-client request rate limiting (HF Spaces friendly).

    State lives only in this process; nothing is persisted.
    """

    def __init__(self):
        # client id -> time.time() timestamps of accepted requests (in-memory only)
        self.request_counts: Dict[str, list] = {}

    @staticmethod
    def _new_client_id() -> str:
        """Return a fresh random 16-character hexadecimal identifier."""
        # Random bytes instead of a hash of the clock: ids derived from
        # time.time() are guessable and can collide for sessions created in
        # the same clock tick, which would make them share one rate bucket.
        return os.urandom(8).hex()

    def get_client_id(self) -> str:
        """Get the client identifier used for rate limiting.

        The id is stored in ``st.session_state`` so it stays stable for the
        session. If session state cannot be accessed, a new random id is
        returned on every call.

        Returns:
            str: A 16-character hexadecimal identifier.
        """
        try:
            if "client_id" not in st.session_state:
                st.session_state.client_id = self._new_client_id()
            return st.session_state.client_id
        except Exception:
            # Fallback for contexts where session_state is unavailable
            return self._new_client_id()

    def check_rate_limit(self) -> Tuple[bool, Optional[str]]:
        """Check whether the current client may make another request.

        A request is allowed while the client has fewer than
        ``config.RATE_LIMIT_REQUESTS`` accepted requests within the last
        ``config.RATE_LIMIT_WINDOW`` seconds. Allowed requests are recorded.

        Returns:
            Tuple[bool, Optional[str]]: ``(True, None)`` when allowed,
            ``(False, message)`` when the limit is exceeded. Any internal
            error is logged and the request is allowed.
        """
        try:
            client_id = self.get_client_id()
            now = time.time()
            window = config.RATE_LIMIT_WINDOW

            with rate_limit_lock:
                # Drop clients with no request inside the window so the table
                # does not grow forever as sessions come and go.
                idle_clients = [
                    cid
                    for cid, stamps in self.request_counts.items()
                    if not any(now - t < window for t in stamps)
                ]
                for cid in idle_clients:
                    del self.request_counts[cid]

                # Keep only this client's requests that are still in the window.
                recent = [
                    t for t in self.request_counts.get(client_id, [])
                    if now - t < window
                ]
                self.request_counts[client_id] = recent

                if len(recent) >= config.RATE_LIMIT_REQUESTS:
                    logger.warning(f"Rate limit exceeded for client: {client_id}")
                    # NOTE(review): the message says "per hour" regardless of
                    # RATE_LIMIT_WINDOW - confirm the window is 3600 seconds.
                    return False, f"Rate limit exceeded. Max {config.RATE_LIMIT_REQUESTS} requests per hour."

                # Record the current request
                recent.append(now)
            return True, None
        except Exception as e:
            logger.error(f"Rate limit check failed: {e}")
            return True, None  # allow request on error
| # ========================= | |
| # File Handling | |
| # ========================= | |
class FileManager:
    """File handling helpers for uploaded audio (HF Spaces).

    All methods are static, so they can be called either on the class or on
    an instance.
    """

    @staticmethod
    def create_secure_temp_file(uploaded_file) -> Optional[str]:
        """Copy an uploaded file into the ``drug_audio`` temp directory.

        The stored name is ``audio_<hash>.<ext>``. The hash is derived from
        the upload's name and the current time, and the extension is the
        part of the upload's name after its last dot, lower-cased and reduced
        to alphanumeric characters so nothing from the client-supplied name
        can inject path separators.

        Args:
            uploaded_file: Object exposing ``name``, ``seek()`` and ``read()``
                (presumably a Streamlit upload - verify against the caller).

        Returns:
            Optional[str]: Path of the written file, or ``None`` on failure.
        """
        temp_path = None
        try:
            temp_dir = os.path.join(tempfile.gettempdir(), "drug_audio")
            os.makedirs(temp_dir, exist_ok=True)

            file_hash = hashlib.md5(f"{uploaded_file.name}{time.time()}".encode()).hexdigest()[:12]

            _, dot, raw_ext = uploaded_file.name.rpartition(".")
            file_ext = "".join(ch for ch in raw_ext.lower() if ch.isalnum()) if dot else ""
            secure_filename = f"audio_{file_hash}.{file_ext}" if file_ext else f"audio_{file_hash}"
            temp_path = os.path.join(temp_dir, secure_filename)

            uploaded_file.seek(0)
            with open(temp_path, "wb") as f:
                f.write(uploaded_file.read())

            logger.info(f"Created secure temp file: {secure_filename}")
            return temp_path
        except Exception as e:
            logger.error(f"Failed to create temp file: {e}")
            # Do not leave a partially written file behind.
            if temp_path:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass
            return None

    @staticmethod
    def cleanup_file(file_path: str, is_temp: bool = True):
        """Delete a temp file; never raises.

        Nothing is deleted when ``is_temp`` is false, when ``file_path`` is
        empty, or when the path does not exist. Failures are logged as
        warnings.

        Args:
            file_path: Path of the file to remove.
            is_temp: Must be true for the file to be deleted; pass ``False``
                for files that have to be kept (e.g. bundled samples).
        """
        try:
            if file_path and os.path.exists(file_path) and is_temp:
                os.unlink(file_path)
                logger.info(f"Deleted temp file: {os.path.basename(file_path)}")
        except Exception as e:
            logger.warning(f"Failed to delete file {file_path}: {e}")

    @staticmethod
    def is_sample_file(file_path: str) -> bool:
        """Return True when ``file_path`` contains the text ``audio_sample``.

        Empty or ``None`` paths yield ``False``.
        """
        if not file_path:
            return False
        return "audio_sample" in file_path
| # ========================= | |
| # File Validation | |
| # ========================= | |
class AudioValidator:
    """Validation of uploaded audio files.

    All methods are static, so they can be called either on the class or on
    an instance.
    """

    @staticmethod
    def validate_file(uploaded_file) -> Tuple[bool, str]:
        """Validate an uploaded audio file.

        Checks, in order: size against ``config.MAX_FILE_SIZE_MB``, extension
        against ``config.ALLOWED_EXTENSIONS``, absence of path components in
        the name, that librosa can decode the data, duration against
        ``config.MAX_AUDIO_DURATION``, and that the signal is not empty or
        silent.

        Args:
            uploaded_file: Object exposing ``size``, ``name``, ``seek()`` and
                ``read()`` (presumably a Streamlit upload - verify against
                the caller).

        Returns:
            Tuple[bool, str]: ``(True, description)`` for a valid file,
            otherwise ``(False, reason)``. Unexpected errors are logged and
            reported as a failed validation rather than raised.
        """
        try:
            # Size check
            size_mb = uploaded_file.size / (1024 * 1024)
            if size_mb > config.MAX_FILE_SIZE_MB:
                return False, f"File too large: {size_mb:.1f}MB (max {config.MAX_FILE_SIZE_MB}MB)"

            # Extension check
            ext = uploaded_file.name.split('.')[-1].lower()
            if ext not in config.ALLOWED_EXTENSIONS:
                return False, f"Unsupported file type: {ext}"

            # Filename check
            if any(c in uploaded_file.name for c in ['..', '/', '\\']):
                return False, "Invalid filename"

            # Decode from a temporary copy to test audio loading.
            temp_path = None
            try:
                uploaded_file.seek(0)
                with tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}") as temp_file:
                    # Record the path first so the file is removed even if
                    # the write below fails.
                    temp_path = temp_file.name
                    temp_file.write(uploaded_file.read())

                # Use librosa for better audio format support
                y, sr = librosa.load(temp_path, sr=None)
                duration_sec = len(y) / sr
                if duration_sec > config.MAX_AUDIO_DURATION:
                    return False, f"Audio too long: {duration_sec:.1f}s (max {config.MAX_AUDIO_DURATION}s)"

                # Check if audio has content. The peak is taken with the
                # array's own max() instead of Python's built-in max(), which
                # would iterate sample by sample (assumes y is the NumPy
                # array returned by librosa.load).
                if len(y) == 0 or abs(y).max() < 1e-6:
                    return False, "Audio file appears to be empty or silent"

                return True, f"Valid audio: {duration_sec:.1f}s, {sr}Hz"
            finally:
                # Clean up temp file; a failed delete must not mask the result.
                if temp_path is not None:
                    try:
                        os.unlink(temp_path)
                    except OSError:
                        pass
        except Exception as e:
            logger.error(f"File validation error: {e}")
            return False, f"Validation failed: {str(e)}"

    @staticmethod
    def is_valid_audio(file_path) -> bool:
        """Check whether librosa can decode the file at ``file_path``.

        Only the first second is loaded. Returns ``False`` (after logging)
        when loading fails or yields no samples.
        """
        try:
            y, sr = librosa.load(file_path, sr=None, duration=1.0)  # Test first second
            return len(y) > 0 and sr > 0
        except Exception as e:
            logger.error(f"Audio validation failed for {file_path}: {e}")
            return False
| # ========================= | |
| # Model Validation | |
| # ========================= | |
class ModelManager:
    """Model validation and loading for HF Spaces.

    All methods are static, so they can be called either on the class or on
    an instance.
    """

    # Files that must be present in the model directory.
    REQUIRED_FILES = (
        "config.json",
        "model_config.json",
        "model_weights.json",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "vocab.txt",
    )

    @staticmethod
    def validate_model_availability() -> Tuple[bool, str]:
        """Check that all required files exist in the model folder.

        The folder is ``config.MODEL_PATH``.

        Returns:
            Tuple[bool, str]: ``(True, "Model ready")`` when the directory and
            every required file exist; otherwise ``(False, reason)``. Errors
            are reported in the message rather than raised.
        """
        try:
            model_path = config.MODEL_PATH
            # isdir (not exists): a plain file at this path is not a usable
            # model directory either.
            if not os.path.isdir(model_path):
                return False, f"Model directory not found: {model_path}"

            missing = [
                name for name in ModelManager.REQUIRED_FILES
                if not os.path.exists(os.path.join(model_path, name))
            ]
            if missing:
                return False, f"Missing model files: {', '.join(missing)}"
            return True, "Model ready"
        except Exception as e:
            return False, f"Model validation failed: {str(e)}"

    @staticmethod
    def load_model(model_path: str):
        """Load the sequence-classification model and its tokenizer.

        ``transformers`` is imported inside the function rather than at
        module import time.

        Args:
            model_path: Directory passed to ``from_pretrained``.

        Returns:
            tuple: ``(model, tokenizer)``.

        Raises:
            Exception: Any loading error is logged and re-raised unchanged.
        """
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        from pathlib import Path

        try:
            model_path = Path(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            logger.info(f"Loaded model & tokenizer from {model_path}")
            return model, tokenizer
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
| # ========================= | |
| # Global instances | |
| # ========================= | |
# Module-level instances, created when this module is imported.
security_manager = SecurityManager()
file_manager = FileManager()
model_manager = ModelManager()