Spaces:
Sleeping
Sleeping
File size: 9,559 Bytes
# src/utils.py
import os
import logging
import tempfile
import hashlib
import time
from typing import Tuple, Optional, Dict
import streamlit as st
from config import config
from pydub import AudioSegment
import threading
import librosa
import soundfile as sf
# =========================
# Logging
# =========================
def setup_production_logging():
    """Configure application logging and return this module's logger.

    When ``config.ENABLE_LOGGING`` is truthy, the root logger is configured
    via ``logging.basicConfig`` with two handlers: a file handler (INFO and
    above) writing ``drug_detector.log`` in the system temp directory, and a
    console handler (WARNING and above). If the log file cannot be opened,
    logging falls back to the console handler alone instead of raising,
    because this function is called at module import time.

    When ``config.ENABLE_LOGGING`` is falsy nothing is configured.

    Returns:
        logging.Logger: the logger named after this module.
    """
    module_logger = logging.getLogger(__name__)
    if not config.ENABLE_LOGGING:
        return module_logger

    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s"
    log_path = os.path.join(tempfile.gettempdir(), "drug_detector.log")

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.WARNING)
    console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

    handlers = []
    file_error = None
    try:
        # FileHandler opens the file immediately, so an unwritable temp
        # directory surfaces here as OSError.
        file_handler = logging.FileHandler(log_path)
    except OSError as e:
        file_error = e
    else:
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(log_format))
        handlers.append(file_handler)
    handlers.append(console_handler)

    # NOTE: basicConfig leaves the root logger untouched if it already has
    # handlers (standard library behaviour).
    logging.basicConfig(level=logging.INFO, handlers=handlers)

    if file_error is not None:
        module_logger.warning(f"File logging disabled ({log_path}): {file_error}")
    module_logger.info("Production logging initialized (HF Spaces)")
    return module_logger
# Module-wide logger. Calling setup_production_logging() here means logging is
# configured as a side effect of importing this module.
logger = setup_production_logging()
# =========================
# Rate-limiting (in-memory)
# =========================
# Lock held by SecurityManager.check_rate_limit while it reads and updates
# its request_counts dictionary.
rate_limit_lock = threading.Lock()
class SecurityManager:
    """In-memory, per-client request rate limiting (HF Spaces friendly).

    All state lives in ``request_counts`` on the instance; nothing is
    persisted, so counters reset whenever the process restarts.
    """

    def __init__(self):
        # client id -> time.time() stamps of requests accepted in the window
        self.request_counts: Dict[str, list] = {}  # in-memory only

    @staticmethod
    def _new_client_id() -> str:
        """Return a random 16-character hexadecimal client identifier."""
        # os.urandom avoids the collisions (and predictability) of hashing
        # the current timestamp: two ids minted in the same clock tick would
        # otherwise share one rate-limit bucket.
        return os.urandom(8).hex()

    def get_client_id(self) -> str:
        """Get the client identifier used for rate limiting.

        The identifier is created on first use and cached in
        ``st.session_state.client_id``.

        Returns:
            str: a 16-character hexadecimal identifier.
        """
        try:
            if "client_id" not in st.session_state:
                st.session_state.client_id = self._new_client_id()
            return st.session_state.client_id
        except Exception:
            # Fallback for contexts where session_state is unavailable.
            # The id cannot be cached there, so each call yields a fresh one.
            return self._new_client_id()

    def check_rate_limit(self) -> Tuple[bool, Optional[str]]:
        """Check whether the current client may make another request.

        Requests older than ``config.RATE_LIMIT_WINDOW`` seconds are
        forgotten. If the client still has ``config.RATE_LIMIT_REQUESTS`` or
        more requests inside the window the request is refused; otherwise it
        is recorded and allowed.

        Returns:
            Tuple[bool, Optional[str]]: ``(True, None)`` when allowed,
            ``(False, message)`` when refused. Any internal error is logged
            and the request is allowed.
        """
        try:
            client_id = self.get_client_id()
            current_time = time.time()
            window = config.RATE_LIMIT_WINDOW
            with rate_limit_lock:
                # Drop expired timestamps for every client and forget clients
                # with nothing left, so the dictionary cannot grow without
                # bound (notably with the uncached fallback ids).
                for known_id in list(self.request_counts):
                    recent = [
                        t for t in self.request_counts[known_id]
                        if current_time - t < window
                    ]
                    if recent:
                        self.request_counts[known_id] = recent
                    else:
                        del self.request_counts[known_id]

                client_requests = self.request_counts.setdefault(client_id, [])
                if len(client_requests) >= config.RATE_LIMIT_REQUESTS:
                    logger.warning(f"Rate limit exceeded for client: {client_id}")
                    # NOTE(review): the message says "per hour" regardless of
                    # config.RATE_LIMIT_WINDOW - confirm the window is 3600s.
                    return False, f"Rate limit exceeded. Max {config.RATE_LIMIT_REQUESTS} requests per hour."
                # Add current request
                client_requests.append(current_time)
                return True, None
        except Exception as e:
            logger.error(f"Rate limit check failed: {e}")
            return True, None  # allow request on error
# =========================
# File Handling
# =========================
class FileManager:
    """Temporary-file helpers for uploaded audio on HF Spaces."""

    @staticmethod
    def create_secure_temp_file(uploaded_file) -> Optional[str]:
        """Copy an uploaded file into the ``drug_audio`` temp directory.

        The file is stored as ``audio_<12 hex chars>.<ext>`` under
        ``<system temp dir>/drug_audio``. ``<ext>`` is the text after the
        last dot of ``uploaded_file.name``, lower-cased and reduced to
        alphanumeric characters.

        Args:
            uploaded_file: object exposing ``name``, ``seek`` and ``read``.

        Returns:
            Optional[str]: path of the new file, or ``None`` on any failure
            (the error is logged).
        """
        temp_path = None
        created = False
        try:
            temp_dir = os.path.join(tempfile.gettempdir(), "drug_audio")
            os.makedirs(temp_dir, exist_ok=True)

            # The extension comes from the client: keep alphanumerics only so
            # it can never carry path separators or ".." into temp_path.
            raw_ext = uploaded_file.name.split('.')[-1].lower()
            file_ext = "".join(ch for ch in raw_ext if ch.isalnum()) or "bin"

            # Random token instead of a hash of name + timestamp, which can
            # repeat for same-named uploads arriving in the same clock tick.
            file_hash = os.urandom(6).hex()
            secure_filename = f"audio_{file_hash}.{file_ext}"
            temp_path = os.path.join(temp_dir, secure_filename)

            uploaded_file.seek(0)
            # "xb" refuses to overwrite an existing file.
            with open(temp_path, "xb") as f:
                created = True
                f.write(uploaded_file.read())
            logger.info(f"Created secure temp file: {secure_filename}")
            return temp_path
        except Exception as e:
            logger.error(f"Failed to create temp file: {e}")
            # Do not leave a partially written file behind; only remove a
            # file this call actually created.
            if created:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass
            return None

    @staticmethod
    def cleanup_file(file_path: str, is_temp: bool = True):
        """Delete ``file_path`` if it exists and ``is_temp`` is true.

        Callers pass ``is_temp=False`` for files that must be kept (such as
        bundled samples). Deletion errors are logged as warnings, never
        raised.
        """
        try:
            if file_path and os.path.exists(file_path) and is_temp:
                os.unlink(file_path)
                logger.info(f"Deleted temp file: {os.path.basename(file_path)}")
        except Exception as e:
            logger.warning(f"Failed to delete file {file_path}: {e}")

    @staticmethod
    def is_sample_file(file_path: str) -> bool:
        """Return True when ``file_path`` contains the text ``audio_sample``.

        An empty or ``None`` path yields ``False``.
        """
        if not file_path:
            return False
        return "audio_sample" in file_path
# =========================
# File Validation
# =========================
class AudioValidator:
    """Audio file validation with better HF Spaces compatibility."""

    @staticmethod
    def validate_file(uploaded_file) -> Tuple[bool, str]:
        """Validate an uploaded audio file.

        Checks, in order: size against ``config.MAX_FILE_SIZE_MB``,
        extension against ``config.ALLOWED_EXTENSIONS``, the file name for
        ``..``, ``/`` or ``\\``, then decodes the audio with ``librosa`` from
        a temporary copy to check its duration against
        ``config.MAX_AUDIO_DURATION`` and that it is not empty or silent.

        Args:
            uploaded_file: object exposing ``size``, ``name``, ``seek`` and
                ``read``.

        Returns:
            Tuple[bool, str]: ``(True, summary)`` for a valid file, otherwise
            ``(False, reason)``. Unexpected errors are logged and reported as
            ``(False, "Validation failed: ...")``.
        """
        temp_path = None
        try:
            # Size check
            size_mb = uploaded_file.size / (1024 * 1024)
            if size_mb > config.MAX_FILE_SIZE_MB:
                return False, f"File too large: {size_mb:.1f}MB (max {config.MAX_FILE_SIZE_MB}MB)"

            # Extension check
            ext = uploaded_file.name.split('.')[-1].lower()
            if ext not in config.ALLOWED_EXTENSIONS:
                return False, f"Unsupported file type: {ext}"

            # Filename check
            if any(c in uploaded_file.name for c in ['..', '/', '\\']):
                return False, "Invalid filename"

            # Create temporary file to test audio loading
            uploaded_file.seek(0)
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}") as temp_file:
                # Record the path before writing so that a failed write is
                # still cleaned up by the finally clause below.
                temp_path = temp_file.name
                temp_file.write(uploaded_file.read())

            # Use librosa for better audio format support
            y, sr = librosa.load(temp_path, sr=None)
            duration_sec = len(y) / sr
            if duration_sec > config.MAX_AUDIO_DURATION:
                return False, f"Audio too long: {duration_sec:.1f}s (max {config.MAX_AUDIO_DURATION}s)"

            # Check if audio has content. The array's own max() runs in
            # native code; builtin max() would loop over every sample in
            # Python.
            if len(y) == 0 or abs(y).max() < 1e-6:
                return False, "Audio file appears to be empty or silent"

            return True, f"Valid audio: {duration_sec:.1f}s, {sr}Hz"
        except Exception as e:
            logger.error(f"File validation error: {e}")
            return False, f"Validation failed: {str(e)}"
        finally:
            # Clean up temp file (runs on every return path above)
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError as e:
                    logger.warning(f"Could not remove validation temp file {temp_path}: {e}")
def is_valid_audio(file_path) -> bool:
    """Report whether ``file_path`` can be decoded as audio by librosa.

    Only the first second is loaded. Returns ``True`` when at least one
    sample was decoded and the sample rate is positive; any loading error is
    logged and reported as ``False``.
    """
    try:
        # Probe just the first second at the file's native sample rate
        samples, sample_rate = librosa.load(file_path, sr=None, duration=1.0)
        if len(samples) == 0:
            return False
        return sample_rate > 0
    except Exception as e:
        logger.error(f"Audio validation failed for {file_path}: {e}")
        return False
# =========================
# Model Validation
# =========================
class ModelManager:
    """Checks for, and loads, the on-disk model used on HF Spaces."""

    @staticmethod
    def validate_model_availability() -> Tuple[bool, str]:
        """Check that ``config.MODEL_PATH`` exists and holds every required file.

        Returns:
            Tuple[bool, str]: ``(True, "Model ready")`` when everything is
            present, otherwise ``(False, reason)`` naming the missing
            directory or files. Unexpected errors are reported as
            ``(False, "Model validation failed: ...")``.
        """
        try:
            model_path = config.MODEL_PATH
            if not os.path.exists(model_path):
                return False, f"Model directory not found: {model_path}"

            # Collect, in listing order, every expected file that is absent
            missing = []
            for filename in (
                "config.json",
                "model_config.json",
                "model_weights.json",
                "special_tokens_map.json",
                "tokenizer.json",
                "tokenizer_config.json",
                "vocab.txt",
            ):
                candidate = os.path.join(model_path, filename)
                if not os.path.exists(candidate):
                    missing.append(filename)

            if not missing:
                return True, "Model ready"
            return False, f"Missing model files: {', '.join(missing)}"
        except Exception as e:
            return False, f"Model validation failed: {str(e)}"

    @staticmethod
    def load_model(model_path: str):
        """Load the tokenizer and sequence-classification model from ``model_path``.

        ``transformers`` is imported inside the function rather than at
        module import time.

        Returns:
            tuple: ``(model, tokenizer)``.

        Raises:
            Exception: any loading error is logged and re-raised unchanged.
        """
        from pathlib import Path
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        try:
            model_path = Path(model_path)
            # Tokenizer first, then the model, both from the same directory
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            logger.info(f"Loaded model & tokenizer from {model_path}")
            return model, tokenizer
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
# =========================
# Global instances
# =========================
# Module-level instances, created when this module is executed.
# security_manager holds the in-memory rate-limit counters.
security_manager = SecurityManager()
file_manager = FileManager()
model_manager = ModelManager()
|