audio-dashboard / src /utils.py
lawlevisan's picture
Create utils.py
ae1eb6c verified
# src/utils.py
import os
import logging
import tempfile
import hashlib
import time
from typing import Tuple, Optional, Dict
import streamlit as st
from config import config
from pydub import AudioSegment
import threading
import librosa
import soundfile as sf
# =========================
# Logging
# =========================
def setup_production_logging():
    """Configure root logging for the app and return this module's logger.

    When ``config.ENABLE_LOGGING`` is true, ``logging.basicConfig`` is called
    with two handlers:

    - INFO and above go to ``<tempdir>/drug_detector.log`` (detailed format).
    - WARNING and above go to the console (short format).

    Otherwise logging configuration is left untouched.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    if not config.ENABLE_LOGGING:
        return logging.getLogger(__name__)

    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s"
    # delay=True: the log file is only opened when the first record is emitted.
    # basicConfig() ignores `handlers` when the root logger already has handlers
    # (e.g. this module imported twice under different names). An eagerly
    # opened FileHandler would then leak its file descriptor.
    # encoding="utf-8": messages include user-supplied (possibly non-ASCII)
    # file names, and the platform default encoding may not represent them.
    file_handler = logging.FileHandler(
        os.path.join(tempfile.gettempdir(), "drug_detector.log"),
        encoding="utf-8",
        delay=True,
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(log_format))

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.WARNING)
    console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

    logging.basicConfig(level=logging.INFO, handlers=[file_handler, console_handler])
    logger = logging.getLogger(__name__)
    logger.info("Production logging initialized (HF Spaces)")
    return logger


logger = setup_production_logging()
# =========================
# Rate-limiting (in-memory)
# =========================
rate_limit_lock = threading.Lock()  # guards all SecurityManager.request_counts updates


class SecurityManager:
    """Production security management (HF Spaces friendly).

    Implements a sliding-window rate limit per client. All state lives in
    process memory (``request_counts``): it is lost on restart and is not
    shared between processes.
    """

    def __init__(self):
        # client id -> time.time() stamps of that client's requests that are
        # still inside the rate-limit window.
        self.request_counts: Dict[str, list] = {}  # in-memory only

    def get_client_id(self) -> str:
        """Return the identifier used to bucket requests for rate limiting.

        Inside a Streamlit session the id is generated once and cached in
        ``st.session_state.client_id``. When session state is unavailable, a
        new id is generated on every call.

        Returns:
            A 16-character lowercase hex string.
        """
        import secrets  # local import keeps this block self-contained

        try:
            if "client_id" not in st.session_state:
                # Random token instead of md5(time.time()): sessions that start
                # within the same clock tick must not share a rate-limit bucket.
                st.session_state.client_id = secrets.token_hex(8)
            return st.session_state.client_id
        except Exception:
            # Fallback for contexts where session_state is unavailable
            return secrets.token_hex(8)

    def check_rate_limit(self) -> Tuple[bool, Optional[str]]:
        """Record a request for the current client if it is within the limit.

        A client may make at most ``config.RATE_LIMIT_REQUESTS`` requests within
        any ``config.RATE_LIMIT_WINDOW`` seconds. Allowed requests are recorded;
        rejected ones are not.

        Returns:
            ``(True, None)`` when the request is allowed.
            ``(False, message)`` when the limit is exceeded.
            Internal errors are logged and the request is allowed (fail-open).
        """
        try:
            client_id = self.get_client_id()
            current_time = time.time()
            window = config.RATE_LIMIT_WINDOW
            with rate_limit_lock:
                # Expire old timestamps for *all* clients and forget clients with
                # no recent requests. Without this the dict grows for the life of
                # the process; the session-less fallback id is new on every call.
                for known_id in list(self.request_counts):
                    recent = [
                        t for t in self.request_counts[known_id]
                        if current_time - t < window
                    ]
                    if recent:
                        self.request_counts[known_id] = recent
                    else:
                        del self.request_counts[known_id]

                history = self.request_counts.setdefault(client_id, [])
                if len(history) >= config.RATE_LIMIT_REQUESTS:
                    logger.warning(f"Rate limit exceeded for client: {client_id}")
                    # NOTE(review): the text says "per hour" regardless of
                    # config.RATE_LIMIT_WINDOW - confirm the window is 3600 s.
                    return False, f"Rate limit exceeded. Max {config.RATE_LIMIT_REQUESTS} requests per hour."
                # Add current request
                history.append(current_time)
                return True, None
        except Exception as e:
            logger.error(f"Rate limit check failed: {e}")
            return True, None  # allow request on error
# =========================
# File Handling
# =========================
class FileManager:
    """Secure file handling for HF Spaces: temp copies of uploads and their cleanup."""

    @staticmethod
    def create_secure_temp_file(uploaded_file) -> Optional[str]:
        """Copy an upload into ``<tempdir>/drug_audio`` under a random name.

        Args:
            uploaded_file: Upload object providing ``name``, ``seek()`` and
                ``read()``. Presumably a Streamlit ``UploadedFile`` - confirm
                against callers.

        Returns:
            The path of the new ``audio_<12 hex>.<ext>`` file, or ``None`` if
            it could not be created. On failure the error is logged and any
            partially written file is removed.
        """
        import secrets  # local import keeps this block self-contained

        temp_path = None
        try:
            temp_dir = os.path.join(tempfile.gettempdir(), "drug_audio")
            os.makedirs(temp_dir, exist_ok=True)
            # Unpredictable name. md5(name + time) collided for two uploads of
            # the same file name within one clock tick.
            file_hash = secrets.token_hex(6)
            # The extension comes from the user-supplied name. Keep ASCII
            # letters/digits only, so it cannot inject path separators.
            raw_ext = uploaded_file.name.split('.')[-1].lower()
            file_ext = "".join(ch for ch in raw_ext if ch.isascii() and ch.isalnum()) or "bin"
            secure_filename = f"audio_{file_hash}.{file_ext}"
            candidate = os.path.join(temp_dir, secure_filename)
            uploaded_file.seek(0)
            # "xb" creates the file exclusively: never overwrite an existing
            # file or write through a pre-planted symlink.
            with open(candidate, "xb") as f:
                temp_path = candidate
                f.write(uploaded_file.read())
            logger.info(f"Created secure temp file: {secure_filename}")
            return temp_path
        except Exception as e:
            logger.error(f"Failed to create temp file: {e}")
            if temp_path is not None:
                # Do not leave a partially written file behind.
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass  # best effort; the original failure is already logged
            return None

    @staticmethod
    def cleanup_file(file_path: str, is_temp: bool = True):
        """Delete a temporary file. Sample files will never be deleted.

        Args:
            file_path: Path of the file to remove. Empty/``None`` is a no-op.
            is_temp: Only files flagged as temporary are deleted.

        Errors are logged as warnings and never raised. A file that is
        already gone is silently ignored.
        """
        try:
            # Enforce the "never delete samples" promise here instead of relying
            # on every caller to pass is_temp=False for bundled samples.
            if not (file_path and is_temp) or FileManager.is_sample_file(file_path):
                return
            try:
                os.unlink(file_path)
            except FileNotFoundError:
                return  # nothing to delete
            logger.info(f"Deleted temp file: {os.path.basename(file_path)}")
        except Exception as e:
            logger.warning(f"Failed to delete file {file_path}: {e}")

    @staticmethod
    def is_sample_file(file_path: str) -> bool:
        """Return True if *file_path* looks like a bundled sample file.

        Heuristic: the path contains the substring ``"audio_sample"``
        anywhere. Accepts ``str`` or path-like values. An empty value or
        ``None`` gives ``False``.
        """
        if not file_path:
            return False
        return "audio_sample" in str(file_path)
# =========================
# File Validation
# =========================
class AudioValidator:
    """Audio file validation with better HF Spaces compatibility."""

    @staticmethod
    def validate_file(uploaded_file) -> Tuple[bool, str]:
        """Validate an uploaded audio file before it is processed.

        Checks, in order:

        1. Size against ``config.MAX_FILE_SIZE_MB``.
        2. Extension against ``config.ALLOWED_EXTENSIONS``.
        3. The file name for ``..`` and path separators.
        4. Decodes the audio with librosa at its native sample rate to
           enforce ``config.MAX_AUDIO_DURATION`` and reject empty or silent
           audio.

        Args:
            uploaded_file: Upload object providing ``size``, ``name``,
                ``seek()`` and ``read()``.

        Returns:
            ``(True, "Valid audio: <seconds>s, <rate>Hz")`` on success,
            otherwise ``(False, reason)``. Unexpected errors are logged and
            reported as ``(False, "Validation failed: ...")`` instead of
            being raised.
        """
        try:
            # Size check
            size_mb = uploaded_file.size / (1024 * 1024)
            if size_mb > config.MAX_FILE_SIZE_MB:
                return False, f"File too large: {size_mb:.1f}MB (max {config.MAX_FILE_SIZE_MB}MB)"
            # Extension check
            ext = uploaded_file.name.split('.')[-1].lower()
            if ext not in config.ALLOWED_EXTENSIONS:
                return False, f"Unsupported file type: {ext}"
            # Filename check
            if any(c in uploaded_file.name for c in ['..', '/', '\\']):
                return False, "Invalid filename"
            # Decode a temporary copy to inspect the real audio content.
            # mkstemp creates the file just before the try/finally, so the file
            # is removed even if writing the upload into it fails.
            uploaded_file.seek(0)
            fd, temp_path = tempfile.mkstemp(suffix=f".{ext}")
            try:
                with os.fdopen(fd, "wb") as temp_file:
                    temp_file.write(uploaded_file.read())
                # Use librosa for better audio format support
                y, sr = librosa.load(temp_path, sr=None)
                duration_sec = len(y) / sr
                if duration_sec > config.MAX_AUDIO_DURATION:
                    return False, f"Audio too long: {duration_sec:.1f}s (max {config.MAX_AUDIO_DURATION}s)"
                # Check if audio has content. abs(y).max() is a vectorised NumPy
                # reduction; the builtin max() walked every sample in Python.
                if len(y) == 0 or abs(y).max() < 1e-6:
                    return False, "Audio file appears to be empty or silent"
                return True, f"Valid audio: {duration_sec:.1f}s, {sr}Hz"
            finally:
                try:
                    os.unlink(temp_path)
                except OSError as cleanup_error:
                    logger.warning(f"Failed to delete validation temp file {temp_path}: {cleanup_error}")
        except Exception as e:
            logger.error(f"File validation error: {e}")
            return False, f"Validation failed: {str(e)}"
def is_valid_audio(file_path) -> bool:
    """Report whether librosa can decode *file_path* as audio.

    Only the first second is decoded (``duration=1.0``) at the file's native
    sample rate (``sr=None``), which keeps the check cheap for long files.

    Returns:
        ``True`` if at least one sample was decoded and the sample rate is
        positive. ``False`` otherwise, including when decoding raises; in
        that case the error is logged.
    """
    try:
        samples, rate = librosa.load(file_path, sr=None, duration=1.0)
        decoded_something = len(samples) > 0
        return decoded_something and rate > 0
    except Exception as err:
        logger.error(f"Audio validation failed for {file_path}: {err}")
        return False
# =========================
# Model Validation
# =========================
class ModelManager:
    """Model validation and loading helpers for HF Spaces."""

    @staticmethod
    def validate_model_availability() -> Tuple[bool, str]:
        """Verify that ``config.MODEL_PATH`` exists and holds every expected file.

        Returns:
            ``(True, "Model ready")`` when everything is present, otherwise
            ``(False, reason)`` naming the missing directory or files. Any
            exception is turned into a ``(False, ...)`` result, not raised.
        """
        try:
            model_dir = config.MODEL_PATH
            if not os.path.exists(model_dir):
                return False, f"Model directory not found: {model_dir}"

            # NOTE(review): confirm this list matches the files actually shipped
            # with the model that load_model() reads via from_pretrained.
            expected = (
                "config.json",
                "model_config.json",
                "model_weights.json",
                "special_tokens_map.json",
                "tokenizer.json",
                "tokenizer_config.json",
                "vocab.txt",
            )
            absent = []
            for name in expected:
                if not os.path.exists(os.path.join(model_dir, name)):
                    absent.append(name)

            if not absent:
                return True, "Model ready"
            return False, f"Missing model files: {', '.join(absent)}"
        except Exception as e:
            return False, f"Model validation failed: {str(e)}"

    @staticmethod
    def load_model(model_path: str):
        """Load a sequence-classification model and its tokenizer from *model_path*.

        Both are loaded with transformers' ``from_pretrained`` from the same
        local directory; the tokenizer is loaded first.

        Returns:
            ``(model, tokenizer)``.

        Raises:
            Whatever ``Path`` or ``from_pretrained`` raises; the error is
            logged before being re-raised.
        """
        from pathlib import Path
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        try:
            source_dir = Path(model_path)
            tokenizer = AutoTokenizer.from_pretrained(source_dir)
            model = AutoModelForSequenceClassification.from_pretrained(source_dir)
            logger.info(f"Loaded model & tokenizer from {source_dir}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
        return model, tokenizer
# =========================
# Global instances
# =========================
# Module-level singletons shared by every importer of this module
# (created in this order: security, file, model).
security_manager, file_manager, model_manager = (
    SecurityManager(),
    FileManager(),
    ModelManager(),
)