"""Bird-species prediction from raw audio bytes.

An audio clip is turned into a mel spectrogram, embedded with the
``BioacousticEngine`` model, projected with the engine's ``reducer`` into the
2-D space of the reference table ``ae.df`` (columns ``UMAP_X`` / ``UMAP_Y``),
and matched against the nearest labelled reference point.
"""

import io
import logging

import numpy as np
import torch
import torchaudio

from pipeline import BioacousticEngine

logger = logging.getLogger(__name__)

# Number of spectrogram frames (last axis) fed to the model: longer
# spectrograms are cropped, shorter ones are zero-padded on the right.
# NOTE(review): presumably the fixed input width the model was trained with —
# confirm before changing.
TARGET_FRAMES = 184

# Lazily-created shared engine; see get_engine().
engine = None


def get_engine():
    """Return the shared ``BioacousticEngine``, creating it on first use.

    Construction is deferred until the first call so the engine is not built
    at import time. There is no lock here: concurrent first calls may each
    construct an engine, and the last assignment wins.
    """
    global engine
    if engine is None:
        engine = BioacousticEngine()
    return engine


def _fit_frames(mel_spec, n_frames=TARGET_FRAMES):
    """Crop or right-zero-pad the last axis of *mel_spec* to exactly *n_frames*."""
    n_have = mel_spec.shape[-1]
    if n_have >= n_frames:
        return mel_spec[..., :n_frames]
    return torch.nn.functional.pad(mel_spec, (0, n_frames - n_have))


def _to_python_scalar(value):
    """Convert a NumPy scalar to the equivalent built-in so the result JSON-encodes."""
    return value.item() if isinstance(value, np.generic) else value


def predict_bird(audio_bytes, distance_threshold=0.8):
    """Predict the bird species in an encoded audio clip.

    Args:
        audio_bytes: Encoded audio in any format ``torchaudio.load`` can decode.
        distance_threshold: Maximum distance, in the projected 2-D space, from
            the nearest labelled reference point for the clip to count as a
            match. Must be > 0.

    Returns:
        A dict whose ``"status"`` is one of:

        * ``"SUCCESS"`` with ``species`` (the matched ``Target_Bird``,
          upper-cased), ``confidence`` (0-100, higher means closer to the
          anchor), ``x``/``y`` (projected coordinates) and ``anchor_id``
          (the matched row's ``File_ID``);
        * ``"NO_BIRD_DETECTED"`` with ``x``/``y`` and the nearest
          ``distance`` when nothing lies within *distance_threshold*;
        * ``"ERROR"`` with a ``message``.

        Failures are reported through the ``"ERROR"`` status (and logged)
        rather than raised.
    """
    # Validate up front: a threshold <= 0 (or NaN) would divide by zero in the
    # confidence formula or reject every clip.
    if not distance_threshold > 0:
        return {
            "status": "ERROR",
            "message": f"distance_threshold must be > 0, got {distance_threshold!r}",
        }

    try:
        ae = get_engine()
        waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
        processed_wave = ae.process_waveform(waveform, sample_rate)

        with torch.no_grad():
            # unsqueeze(0) adds a leading batch axis before the model call.
            mel_spec = _fit_frames(ae.preprocessor.process(processed_wave).unsqueeze(0))
            embedding = ae.model(mel_spec).cpu().numpy().reshape(1, -1)

        coords = ae.reducer.transform(embedding)
        x, y = float(coords[0][0]), float(coords[0][1])

        # Rows with Cluster_ID == -1 are excluded from matching.
        # NOTE(review): presumably the clusterer's "noise" label — confirm.
        valid_df = ae.df[ae.df["Cluster_ID"] != -1]
        distances = np.hypot(valid_df["UMAP_X"] - x, valid_df["UMAP_Y"] - y)
        min_dist = float(distances.min())

        # min() is NaN when there are no valid reference rows or the projected
        # coordinates are non-finite; `NaN > threshold` is False, so without
        # this guard we would fall through to an idxmin() that cannot succeed.
        if not np.isfinite(min_dist):
            return {
                "status": "ERROR",
                "message": "could not compute a finite distance to any labelled reference point",
            }

        if min_dist > distance_threshold:
            return {"status": "NO_BIRD_DETECTED", "x": x, "y": y, "distance": min_dist}

        match_row = valid_df.loc[distances.idxmin()]
        # Linear score: 100 at the anchor itself, 0 at the threshold boundary.
        confidence = max(0.0, (distance_threshold - min_dist) / distance_threshold) * 100
        return {
            "status": "SUCCESS",
            "species": match_row["Target_Bird"].upper(),
            "confidence": round(confidence, 2),
            "x": x,
            "y": y,
            "anchor_id": _to_python_scalar(match_row["File_ID"]),
        }
    except Exception as e:
        # Boundary handler: callers get a status dict instead of an exception,
        # but keep the traceback in the logs so failures remain debuggable.
        logger.exception("predict_bird failed")
        return {"status": "ERROR", "message": str(e)}