File size: 9,559 Bytes
ae1eb6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
# src/utils.py
import hashlib
import logging
import os
import secrets
import tempfile
import threading
import time
from typing import Tuple, Optional, Dict

import librosa
import soundfile as sf
import streamlit as st
from pydub import AudioSegment

from config import config
# =========================
# Logging
# =========================
def setup_production_logging():
    """Configure root logging for HF Spaces and return this module's logger.

    When ``config.ENABLE_LOGGING`` is true and the root logger has no handlers
    yet, installs via ``logging.basicConfig`` (root level INFO):

    * a file handler for INFO and above, writing UTF-8 text to
      ``<tempdir>/drug_detector.log`` with a detailed format
      (time, logger name, level, function:line);
    * a console handler (stderr) for WARNING and above, in a short format.

    ``logging.basicConfig`` ignores new handlers once the root logger is
    configured, so in that case no handlers are created at all; building them
    anyway would open a log file that is never attached or closed.
    When logging is disabled nothing is configured.

    Returns:
        logging.Logger: the logger for this module.
    """
    log = logging.getLogger(__name__)
    if not config.ENABLE_LOGGING:
        return log

    if not logging.getLogger().handlers:
        log_format = "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s"

        file_handler = logging.FileHandler(
            os.path.join(tempfile.gettempdir(), "drug_detector.log"),
            encoding="utf-8",
        )
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(log_format))

        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.WARNING)
        console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

        logging.basicConfig(level=logging.INFO, handlers=[file_handler, console_handler])

    log.info("Production logging initialized (HF Spaces)")
    return log


logger = setup_production_logging()

# =========================
# Rate-limiting (in-memory)
# =========================
# Module-wide lock serialising every access to SecurityManager.request_counts.
rate_limit_lock = threading.Lock()


class SecurityManager:
    """Production security management (HF Spaces friendly).

    Implements a sliding-window rate limit keyed by a per-session client id.
    All state lives in this process's memory: it is lost on restart and is
    not shared between processes.
    """

    def __init__(self):
        # client_id -> time.time() stamps of that client's recent requests
        self.request_counts: Dict[str, list] = {}  # in-memory only
        # time.time() of the last sweep that dropped idle clients
        self._last_sweep: float = 0.0

    @staticmethod
    def _new_client_id() -> str:
        """Return a random 16-hex-character id.

        Random rather than time-derived so that sessions starting within the
        same clock tick never share a rate-limit bucket.
        """
        return secrets.token_hex(8)

    def get_client_id(self) -> str:
        """Get client identifier for rate limiting.

        The id is cached in Streamlit's ``session_state`` so it is stable for
        the session. If ``session_state`` cannot be used, a fresh id is
        returned on every call, so such requests are effectively not rate
        limited.
        """
        try:
            if "client_id" not in st.session_state:
                st.session_state.client_id = self._new_client_id()
            return st.session_state.client_id
        except Exception:
            # Fallback for contexts where session_state is unavailable
            return self._new_client_id()

    @staticmethod
    def _describe_window(window: float) -> str:
        """Human-readable rate-limit window, e.g. ``'per hour'``."""
        if window == 3600:
            return "per hour"
        return f"per {window:g} seconds"

    def _sweep_idle_clients(self, now: float, window: float) -> None:
        """Forget clients with no request inside the trailing window.

        Without this every client id ever seen would keep an entry forever.
        Runs at most once per window to keep the per-request cost low.
        Caller must hold ``rate_limit_lock``.
        """
        if now - self._last_sweep < window:
            return
        self._last_sweep = now
        idle = [
            cid for cid, stamps in self.request_counts.items()
            if not stamps or now - max(stamps) >= window
        ]
        for cid in idle:
            del self.request_counts[cid]

    def check_rate_limit(self) -> Tuple[bool, Optional[str]]:
        """Record a request for the current client and enforce the rate limit.

        A client may make at most ``config.RATE_LIMIT_REQUESTS`` requests in
        any trailing ``config.RATE_LIMIT_WINDOW`` seconds; rejected requests
        are not recorded.

        Returns:
            ``(True, None)`` when the request is allowed, ``(False, message)``
            when the limit is exceeded. Unexpected errors are logged and the
            request is allowed (fail-open).
        """
        try:
            client_id = self.get_client_id()
            now = time.time()
            window = config.RATE_LIMIT_WINDOW
            max_requests = config.RATE_LIMIT_REQUESTS

            with rate_limit_lock:
                self._sweep_idle_clients(now, window)

                # Keep only this client's requests inside the trailing window.
                recent = [t for t in self.request_counts.get(client_id, ()) if now - t < window]
                self.request_counts[client_id] = recent

                if len(recent) >= max_requests:
                    logger.warning(f"Rate limit exceeded for client: {client_id}")
                    return False, (
                        f"Rate limit exceeded. Max {max_requests} requests "
                        f"{self._describe_window(window)}."
                    )

                recent.append(now)

            return True, None
        except Exception as e:
            logger.error(f"Rate limit check failed: {e}")
            return True, None  # allow request on error

# =========================
# File Handling
# =========================
class FileManager:
    """Secure file handling for HF Spaces.

    Copies uploads into ``<tempdir>/drug_audio`` under random names and
    deletes them afterwards, never touching bundled sample files.
    """

    @staticmethod
    def _safe_extension(filename: str) -> str:
        """Return a lower-case, ASCII-alphanumeric extension for *filename*.

        The name comes from untrusted upload metadata, so every character
        other than ASCII letters/digits (path separators, dots, ...) is
        dropped. Falls back to ``"bin"`` when nothing usable remains.
        """
        raw = filename.rsplit('.', 1)[-1].lower()
        cleaned = ''.join(ch for ch in raw if ch.isascii() and ch.isalnum())
        return cleaned or "bin"

    @staticmethod
    def create_secure_temp_file(uploaded_file) -> Optional[str]:
        """Copy an upload into ``<tempdir>/drug_audio`` under a random name.

        Args:
            uploaded_file: file-like object exposing ``name``, ``seek`` and
                ``read`` (presumably a Streamlit ``UploadedFile``).

        Returns:
            Path of the new file, or ``None`` on any failure (logged). A file
            this call created is removed again if the copy fails.
        """
        created = False
        try:
            temp_dir = os.path.join(tempfile.gettempdir(), "drug_audio")
            os.makedirs(temp_dir, exist_ok=True)

            # Random, not derived from name/time: concurrent uploads cannot collide.
            file_ext = FileManager._safe_extension(uploaded_file.name)
            secure_filename = f"audio_{secrets.token_hex(6)}.{file_ext}"
            temp_path = os.path.join(temp_dir, secure_filename)

            uploaded_file.seek(0)
            # "xb": fail rather than silently overwrite an existing file.
            with open(temp_path, "xb") as f:
                created = True
                f.write(uploaded_file.read())

            logger.info(f"Created secure temp file: {secure_filename}")
            return temp_path
        except Exception as e:
            logger.error(f"Failed to create temp file: {e}")
            if created:
                try:
                    os.unlink(temp_path)
                except OSError as cleanup_err:
                    logger.warning(f"Failed to delete file {temp_path}: {cleanup_err}")
            return None

    @staticmethod
    def cleanup_file(file_path: str, is_temp: bool = True):
        """Securely delete temp file. Sample files will never be deleted.

        Nothing happens when *file_path* is empty, when ``is_temp`` is false,
        when the path is a bundled sample (checked here via
        ``is_sample_file``, not left to callers), or when the file is already
        gone. Other failures are logged, never raised.
        """
        try:
            if not file_path or not is_temp or FileManager.is_sample_file(file_path):
                return
            os.unlink(file_path)
            logger.info(f"Deleted temp file: {os.path.basename(file_path)}")
        except FileNotFoundError:
            pass  # already gone: nothing to clean up
        except Exception as e:
            logger.warning(f"Failed to delete file {file_path}: {e}")

    @staticmethod
    def is_sample_file(file_path: str) -> bool:
        """Check if file belongs to bundled sample directory.

        Detection is a substring test for ``"audio_sample"`` anywhere in the
        path; ``os.PathLike`` objects are accepted as well as strings.
        """
        if not file_path:
            return False
        return "audio_sample" in os.fspath(file_path)

# =========================
# File Validation
# =========================
class AudioValidator:
    """Audio file validation with better HF Spaces compatibility"""

    @staticmethod
    def validate_file(uploaded_file) -> Tuple[bool, str]:
        """Validate an uploaded audio file before processing.

        Checks, in order: size (``config.MAX_FILE_SIZE_MB``), extension
        (``config.ALLOWED_EXTENSIONS``), filename characters, then decodes the
        audio with librosa to check the duration (``config.MAX_AUDIO_DURATION``)
        and that it is neither empty nor silent.

        Args:
            uploaded_file: file-like upload exposing ``size``, ``name``,
                ``seek`` and ``read`` (presumably a Streamlit ``UploadedFile``).

        Returns:
            ``(is_valid, message)``: a description of the audio on success or
            the reason for rejection. Unexpected errors are logged and
            reported as ``(False, "Validation failed: ...")``.
        """
        try:
            # Size check
            size_mb = uploaded_file.size / (1024 * 1024)
            if size_mb > config.MAX_FILE_SIZE_MB:
                return False, f"File too large: {size_mb:.1f}MB (max {config.MAX_FILE_SIZE_MB}MB)"

            # Extension check
            ext = uploaded_file.name.split('.')[-1].lower()
            if ext not in config.ALLOWED_EXTENSIONS:
                return False, f"Unsupported file type: {ext}"

            # Filename check
            if any(c in uploaded_file.name for c in ['..', '/', '\\']):
                return False, "Invalid filename"

            # librosa needs a real path. delete=False so the file can be closed
            # before librosa reopens it by name; the finally below removes it,
            # including when writing the upload fails.
            uploaded_file.seek(0)
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}")
            temp_path = temp_file.name
            try:
                with temp_file:
                    temp_file.write(uploaded_file.read())

                y, sr = librosa.load(temp_path, sr=None)
                duration_sec = len(y) / sr

                if duration_sec > config.MAX_AUDIO_DURATION:
                    return False, f"Audio too long: {duration_sec:.1f}s (max {config.MAX_AUDIO_DURATION}s)"

                # abs(y).max() is vectorised; builtin max() would loop in Python.
                if len(y) == 0 or abs(y).max() < 1e-6:
                    return False, "Audio file appears to be empty or silent"

                return True, f"Valid audio: {duration_sec:.1f}s, {sr}Hz"
            finally:
                try:
                    os.unlink(temp_path)
                except OSError as cleanup_err:
                    logger.warning(f"Failed to delete file {temp_path}: {cleanup_err}")

        except Exception as e:
            logger.error(f"File validation error: {e}")
            return False, f"Validation failed: {str(e)}"

def is_valid_audio(file_path) -> bool:
    """Return True when librosa can decode audio from *file_path*.

    Only the first second is decoded (``duration=1.0``) to keep the probe
    cheap. The file counts as valid when that yields at least one sample at a
    positive sample rate. Any error is logged and reported as False.
    """
    try:
        samples, rate = librosa.load(file_path, sr=None, duration=1.0)
        has_samples = len(samples) > 0
        return has_samples and rate > 0
    except Exception as err:
        logger.error(f"Audio validation failed for {file_path}: {err}")
        return False

# =========================
# Model Validation
# =========================
class ModelManager:
    """Model availability checks and loading for HF Spaces."""

    @staticmethod
    def validate_model_availability() -> Tuple[bool, str]:
        """Verify that ``config.MODEL_PATH`` exists and holds every expected file.

        Returns:
            ``(True, "Model ready")`` if everything is present; otherwise
            ``(False, reason)`` naming the missing directory or files. Any
            exception is reported as ``(False, "Model validation failed: <error>")``.
        """
        try:
            root = config.MODEL_PATH
            if not os.path.exists(root):
                return False, f"Model directory not found: {root}"

            # Each of these must sit directly inside the model directory.
            expected_files = (
                "config.json",
                "model_config.json",
                "model_weights.json",
                "special_tokens_map.json",
                "tokenizer.json",
                "tokenizer_config.json",
                "vocab.txt",
            )
            absent = [
                name for name in expected_files
                if not os.path.exists(os.path.join(root, name))
            ]
            if not absent:
                return True, "Model ready"
            return False, "Missing model files: " + ", ".join(absent)
        except Exception as err:
            return False, f"Model validation failed: {err}"

    @staticmethod
    def load_model(model_path: str):
        """Load a sequence-classification model and its tokenizer.

        Args:
            model_path: directory handed to ``from_pretrained`` (converted to
                ``pathlib.Path`` first).

        Returns:
            ``(model, tokenizer)``.

        Raises:
            Any loading error, re-raised after being logged.
        """
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        from pathlib import Path

        try:
            source = Path(model_path)
            tokenizer = AutoTokenizer.from_pretrained(source)
            model = AutoModelForSequenceClassification.from_pretrained(source)
            logger.info(f"Loaded model & tokenizer from {source}")
            return model, tokenizer
        except Exception as err:
            logger.error(f"Failed to load model: {err}")
            raise

# =========================
# Global instances
# =========================
# Module-level shared instances, constructed once at import time.
security_manager, file_manager, model_manager = (
    SecurityManager(),
    FileManager(),
    ModelManager(),
)