# Provenance (residue from the Hugging Face file page): uploaded by ai-mitra via
# "Upload inference.py with huggingface_hub", commit 14cae56 (verified).
"""
Standalone FastAPI Heart Sound Analysis Server with Comprehensive Analysis
Requires only: inference.py + heartbeat-anomaly-detector-model.pt + audio files
Install dependencies:
pip install fastapi uvicorn torch librosa numpy scipy matplotlib peakutils
Run server:
uvicorn inference:app --host 0.0.0.0 --port 8000
API Usage:
GET /hb?filename=[PATH]/Heart_Failure_Sound.mp3
"""
import os
import io
import base64
from datetime import datetime
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import librosa
from scipy import signal
from scipy.signal import medfilt, savgol_filter, find_peaks
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# Try to import BioSPPy (optional but recommended). Any import-time failure
# (missing package, incompatible dependency) falls back to simple peak
# detection; `except Exception` rather than a bare `except:` so Ctrl-C /
# SystemExit during startup are not swallowed.
try:
    from biosppy.signals import pcg
    BIOSPPY_AVAILABLE = True
except Exception as _biosppy_error:
    BIOSPPY_AVAILABLE = False
    print(f"Warning: BioSPPy not available ({_biosppy_error}). Using fallback peak detection.")
app = FastAPI(
    title="Heart Sound Analysis API",
    version="2.0.0",
)

# CORS is fully open: every origin, HTTP method and request header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed cross-origin requests. No endpoint in this file reads
# cookies or auth headers, so allow_credentials=False is probably safe — confirm
# with the client owners before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# ============================================================================
# CONFIGURATION
# ============================================================================
# Checkpoint loaded by HeartSoundPredictor at import time.
MODEL_PATH = "heartbeat-anomaly-detector-model.pt"
# Label for each classifier output index: predict() looks up the argmax index
# in this list, so the order must match the model's logits.
CLASS_NAMES = ["artifact", "extrahs", "murmur", "normal"]
SAMPLE_RATE = 2000  # Hz; audio is resampled to this rate when loaded
AUDIO_LENGTH = 30  # seconds; clips are truncated to this (and zero-padded for the CNN)

# Per-class feature statistics (mean / std). Not referenced by the code in this
# file; presumably summary statistics from the training set — TODO confirm.
TRAINING_BASELINES = {
    "normal": {
        "avg_heart_rate": dict(mean=104.5, std=7.6),
        "irregularity_score": dict(mean=0.317, std=0.074),
        "s1_s2_amp_ratio": dict(mean=1.171, std=0.244),
    },
    "murmur": {
        "avg_heart_rate": dict(mean=87.5, std=27.5),
        "irregularity_score": dict(mean=0.185, std=0.150),
        "s1_s2_amp_ratio": dict(mean=0.733, std=0.092),
    },
    "extrahs": {
        "avg_heart_rate": dict(mean=82.3, std=9.2),
        "irregularity_score": dict(mean=0.570, std=0.068),
        "s1_s2_amp_ratio": dict(mean=0.836, std=0.151),
    },
}
# ============================================================================
# AUDIO PROCESSING FUNCTIONS
# ============================================================================
def load_audio_file(file_path):
    """Load up to AUDIO_LENGTH seconds of audio at SAMPLE_RATE Hz.

    librosa is tried first. If it raises, the file is read with soundfile,
    resampled with scipy.signal.resample when its native rate differs, and
    truncated to AUDIO_LENGTH seconds. soundfile may return 2-D data; the
    preprocessing step averages over axis 1 in that case.

    Args:
        file_path: path to an audio file readable by librosa or soundfile.

    Returns:
        (audio_data, sample_rate) on success, or (None, None) if both loaders
        fail. Both failures are printed rather than raised.
    """
    try:
        audio_data, sample_rate = librosa.load(file_path, sr=SAMPLE_RATE, duration=AUDIO_LENGTH, offset=0.0)
        return audio_data, sample_rate
    except Exception as e:
        print(f"Error loading audio: {e}")

    try:
        import soundfile as sf
        from scipy.signal import resample

        audio_data, sample_rate = sf.read(file_path)
        if sample_rate != SAMPLE_RATE:
            # Resample the whole clip first, then truncate.
            audio_data = resample(audio_data, int(len(audio_data) * SAMPLE_RATE / sample_rate))
            sample_rate = SAMPLE_RATE
        max_samples = int(AUDIO_LENGTH * sample_rate)
        if len(audio_data) > max_samples:
            audio_data = audio_data[:max_samples]
        return audio_data, sample_rate
    except Exception as e:
        # Previously swallowed silently; log so "Failed to load audio" is diagnosable.
        print(f"Fallback audio loading failed: {e}")
        return None, None
def enhanced_preprocess_signal(audio_data, sample_rate):
    """Clean and normalise a raw heart-sound waveform.

    The exact sequence of operations below is also what the CNN receives in
    HeartSoundPredictor.predict, so the arithmetic must not be reordered.
    Steps:
      1. Average over axis 1 when the input is 2-D (assumed (samples, channels)).
      2. Convert to float64 and remove the mean, then apply a ~10 ms median filter.
      3. scipy Butterworth band-pass (N=6, 20-200 Hz) applied forwards and
         backwards with filtfilt, only if 200 Hz is below Nyquist.
      4. filtfilt notch filters (Q=30) at 50 Hz and 60 Hz (typical mains-hum
         frequencies) when each lies below 90 % of Nyquist.
      5. Mix 70 % signal with 30 % of the mean-removed Hilbert envelope, the
         envelope first median-smoothed over ~5 ms.
      6. Scale so the RMS becomes 1/3 (skipped for an all-zero signal), then
         clip to [-1, 1].
      7. For sample rates >= 400 Hz, Savitzky-Golay smoothing (~2 ms window,
         cubic). If scipy rejects the window (window 3 with polyorder 3, which
         happens below 2 kHz), a moving average of the same width is used.

    Args:
        audio_data: numpy array, 1-D or 2-D.
        sample_rate: sampling rate in Hz.

    Returns:
        1-D float64 numpy array with the same number of samples as the input.
    """
    if len(audio_data.shape) > 1:
        audio_data = np.mean(audio_data, axis=1)
    audio_data = audio_data.astype(np.float64)
    audio_data = audio_data - np.mean(audio_data)
    audio_data = medfilt(audio_data, kernel_size=_odd_window(sample_rate, 0.01))

    nyquist = sample_rate / 2
    low_primary = 20 / nyquist
    high_primary = 200 / nyquist
    if high_primary < 1.0:
        b_primary, a_primary = signal.butter(6, [low_primary, high_primary], btype='band')
        audio_data = signal.filtfilt(b_primary, a_primary, audio_data)

    if sample_rate >= 200:
        for freq in [50, 60]:
            if freq < nyquist * 0.9:
                b_notch, a_notch = signal.iirnotch(freq / nyquist, 30)
                audio_data = signal.filtfilt(b_notch, a_notch, audio_data)

    envelope = np.abs(signal.hilbert(audio_data))
    envelope = medfilt(envelope, kernel_size=_odd_window(sample_rate, 0.005))
    audio_data = 0.7 * audio_data + 0.3 * (envelope - np.mean(envelope))

    rms = np.sqrt(np.mean(audio_data**2))
    if rms > 0:
        audio_data = audio_data / (rms * 3)
    audio_data = np.clip(audio_data, -1.0, 1.0)

    if sample_rate >= 400:
        smooth_size = _odd_window(sample_rate, 0.002)
        if smooth_size < len(audio_data):
            try:
                audio_data = savgol_filter(audio_data, smooth_size, 3)
            except ValueError:
                # savgol_filter requires polyorder < window_length; fall back
                # to a moving average of the same width.
                audio_data = np.convolve(audio_data, np.ones(smooth_size)/smooth_size, mode='same')
    return audio_data


def _odd_window(sample_rate, seconds):
    """Return an odd window length of at least 3 samples spanning about `seconds`."""
    size = max(3, int(sample_rate * seconds))
    return size + 1 if size % 2 == 0 else size
def analyze_heart_sound(audio_data, sample_rate):
    """Detect S1/S2 peaks and heart rate in a heart-sound recording.

    The signal is first passed through enhanced_preprocess_signal. BioSPPy's
    PCG segmenter is used when available. If BioSPPy is missing or raises,
    simple amplitude peak picking is used instead.

    Args:
        audio_data: waveform as returned by load_audio_file.
        sample_rate: sampling rate in Hz.

    Returns:
        dict with keys 'filtered_signal', 'heart_rate' (BPM values),
        's1_peaks' / 's2_peaks' (sample-index arrays), 'avg_heart_rate'
        (0 when no rate could be computed), 'num_s1', 'num_s2',
        'sample_rate' and 'method' ('BioSPPy' or 'Simple Peak Detection').
    """
    processed_audio = enhanced_preprocess_signal(audio_data, sample_rate)
    if BIOSPPY_AVAILABLE:
        try:
            return _analyze_with_biosppy(processed_audio, sample_rate)
        except Exception as e:
            print(f"BioSPPy failed: {e}")
    return _analyze_with_peak_detection(processed_audio, sample_rate)


def _analyze_with_biosppy(processed_audio, sample_rate):
    """Segment with biosppy.signals.pcg.pcg; detected peaks are assumed to alternate S1, S2, S1, ..."""
    with np.errstate(all='ignore'):
        out = pcg.pcg(signal=processed_audio, sampling_rate=int(sample_rate), show=False)
        filtered_signal = out['filtered']
        # Prefer the instantaneous rate series; fall back to the single summary
        # rate, then to "no data". Missing keys raise KeyError on BioSPPy results.
        try:
            heart_rate = np.array(out['inst_heart_rate'])
        except (KeyError, IndexError, TypeError, ValueError):
            try:
                heart_rate = np.array([out['heart_rate']])
            except (KeyError, IndexError, TypeError, ValueError):
                heart_rate = []
        all_peaks = np.array(out['peaks'])
        s1_peaks = all_peaks[::2]
        s2_peaks = all_peaks[1::2]
        avg_heart_rate = np.mean(heart_rate) if len(heart_rate) > 0 else 0
    return {
        'filtered_signal': filtered_signal,
        'heart_rate': heart_rate,
        's1_peaks': s1_peaks,
        's2_peaks': s2_peaks,
        'avg_heart_rate': avg_heart_rate,
        'num_s1': len(s1_peaks),
        'num_s2': len(s2_peaks),
        'sample_rate': sample_rate,
        'method': 'BioSPPy'
    }


def _analyze_with_peak_detection(processed_audio, sample_rate):
    """Fallback: pick |signal| peaks and split them into S1/S2 by median amplitude."""
    # NOTE(review): distance = 0.6 s means consecutive peaks are at least 0.6 s
    # apart, so the computed rate can never exceed 100 BPM and the tachycardia
    # rules (> 100 BPM) can never fire in this mode. Tune with domain experts.
    peaks, _ = find_peaks(np.abs(processed_audio), height=0.2, distance=int(sample_rate * 0.6))
    if len(peaks) > 1:
        intervals = np.diff(peaks) / sample_rate
        heart_rates = 60 / intervals
        avg_heart_rate = np.mean(heart_rates)
        amplitudes = np.abs(processed_audio[peaks])
        median_amp = np.median(amplitudes)
        # Louder half labelled S1, quieter half S2.
        s1_peaks = peaks[amplitudes >= median_amp]
        s2_peaks = peaks[amplitudes < median_amp]
    else:
        heart_rates = []
        avg_heart_rate = 0
        # Empty integer arrays, so both branches return the same type.
        s1_peaks = np.array([], dtype=int)
        s2_peaks = np.array([], dtype=int)
    return {
        'filtered_signal': processed_audio,
        'heart_rate': heart_rates,
        's1_peaks': s1_peaks,
        's2_peaks': s2_peaks,
        'avg_heart_rate': avg_heart_rate,
        'num_s1': len(s1_peaks),
        'num_s2': len(s2_peaks),
        'sample_rate': sample_rate,
        'method': 'Simple Peak Detection'
    }
# ============================================================================
# RULE-BASED CONDITION DETECTION
# ============================================================================
def detect_arrhythmias(analysis_results):
    """Apply heart-rate and rhythm-variability rules to an analysis result.

    Args:
        analysis_results: dict as produced by analyze_heart_sound. Reads
            'avg_heart_rate', 'heart_rate' (sequence of BPM values) and,
            optionally, 'method'.

    Returns:
        list of finding dicts (condition, description, severity,
        recommendation, source). Rate bands are mutually exclusive. The
        rhythm check needs more than two rate samples and uses stricter
        thresholds when the rates came from BioSPPy.
    """
    rate = analysis_results['avg_heart_rate']
    rates = analysis_results['heart_rate']
    from_biosppy = analysis_results.get('method', 'Unknown') == "BioSPPy"
    findings = []

    def add(condition, description, severity, recommendation):
        findings.append({
            "condition": condition,
            "description": description,
            "severity": severity,
            "recommendation": recommendation,
            "source": "Rule-Based Analysis",
        })

    # Heart-rate bands (BPM).
    if rate > 150:
        add("Severe Tachycardia",
            f"Heart rate significantly elevated (>150 BPM): {rate:.1f} BPM",
            "High", "Immediate medical attention required")
    elif rate > 100:
        add("Tachycardia", "Heart rate elevated (100-150 BPM)",
            "Moderate", "Consult cardiologist for evaluation")
    elif rate < 40:
        add("Severe Bradycardia", "Heart rate significantly low (<40 BPM)",
            "High", "Emergency medical evaluation needed")
    elif rate < 50:
        add("Bradycardia", "Heart rate below normal (40-50 BPM)",
            "Moderate", "Medical evaluation recommended")

    # Rhythm irregularity = coefficient of variation of the rate series.
    if len(rates) > 2:
        mean_rate = np.mean(rates)
        score = np.std(rates) / mean_rate if mean_rate > 0 else 0
        if score > (0.15 if from_biosppy else 0.25):
            add("Irregular Rhythm (Possible Atrial Fibrillation)",
                f"High heart rate variability detected (Irregularity: {score:.3f})",
                "High" if from_biosppy else "Moderate",
                "ECG and cardiology consultation required")
        elif score > (0.10 if from_biosppy else 0.18):
            add("Mild Rhythm Irregularity",
                f"Moderate heart rate variability (Irregularity: {score:.3f})",
                "Low", "Monitor and follow up with healthcare provider")
    return findings
def assess_heart_attack_risk(analysis_results, all_findings):
    """Score cardiac risk from heart rate, rhythm irregularity and prior findings.

    Points: tachycardia >150 BPM +3, 120-150 BPM +2, or bradycardia <45 BPM +2.
    Irregularity (rate std / mean, needs >2 samples) >0.20 gives +3, >0.15 gives +2.
    Two or more findings with severity 'High' give +2. Rate >100 BPM combined
    with irregularity >0.15 gives a further +2. Tiers: >=6 HIGH, >=3 MODERATE,
    >=1 LOW-MODERATE, otherwise LOW.

    Args:
        analysis_results: dict from analyze_heart_sound ('avg_heart_rate', 'heart_rate').
        all_findings: list of finding dicts; only their 'severity' is read.

    Returns:
        dict with heart_attack_risk_flag (score > 0), risk_level, risk_score,
        risk_factors, risk_description and risk_recommendation.
    """
    rate = analysis_results['avg_heart_rate']
    rates = analysis_results['heart_rate']
    irregularity = 0
    if len(rates) > 2:
        mean_rate = np.mean(rates)
        irregularity = np.std(rates) / mean_rate if mean_rate > 0 else 0

    contributions = []  # (points, human-readable factor), in reporting order
    if rate > 150:
        contributions.append((3, "Severe tachycardia (>150 BPM) - high cardiac stress"))
    elif rate > 120:
        contributions.append((2, "Moderate tachycardia (120-150 BPM) - increased cardiac workload"))
    elif rate < 45:
        contributions.append((2, "Severe bradycardia (<45 BPM) - possible conduction system damage"))

    if irregularity > 0.20:
        contributions.append((3, "High rhythm irregularity suggesting atrial fibrillation"))
    elif irregularity > 0.15:
        contributions.append((2, "Significant rhythm irregularity detected"))

    severe_count = sum(1 for finding in all_findings if finding.get('severity') == 'High')
    if severe_count >= 2:
        contributions.append((2, f"Multiple high-severity cardiac abnormalities ({severe_count} detected)"))

    if rate > 100 and irregularity > 0.15:
        contributions.append((2, "Combined tachycardia and rhythm irregularity - high cardiac stress"))

    risk_score = sum(points for points, _ in contributions)
    risk_factors = [factor for _, factor in contributions]

    tiers = (
        (6, "HIGH",
         "Multiple significant cardiac risk factors detected. High risk for cardiac events including myocardial infarction.",
         "URGENT: Seek immediate emergency medical evaluation. Call emergency services if experiencing chest pain, shortness of breath, or other cardiac symptoms."),
        (3, "MODERATE",
         "Moderate cardiac risk factors present. Elevated risk for future cardiac events.",
         "Schedule urgent cardiology consultation within 24-48 hours. Monitor for cardiac symptoms."),
        (1, "LOW-MODERATE",
         "Some cardiac risk factors identified. Regular monitoring recommended.",
         "Schedule cardiology follow-up within 1-2 weeks. Lifestyle modifications recommended."),
    )
    for threshold, risk_level, risk_description, risk_recommendation in tiers:
        if risk_score >= threshold:
            break
    else:
        risk_level = "LOW"
        risk_description = "No significant acute cardiac risk factors detected based on heart sound analysis."
        risk_recommendation = "Continue regular preventive care and heart-healthy lifestyle."

    return {
        'heart_attack_risk_flag': risk_score > 0,
        'risk_level': risk_level,
        'risk_score': risk_score,
        'risk_factors': risk_factors,
        'risk_description': risk_description,
        'risk_recommendation': risk_recommendation
    }
# ============================================================================
# VISUALIZATION
# ============================================================================
def save_analysis_plot(audio_data, analysis_results, sample_rate, filename):
    """Render a three-panel diagnostic PNG and write it to `filename`.

    Panels: the waveform as loaded, the filtered signal with S1 (red) / S2
    (blue) peak markers, and the heart-rate series (or a placeholder text).

    Args:
        audio_data: waveform as returned by load_audio_file.
        analysis_results: dict from analyze_heart_sound; reads
            'filtered_signal', 's1_peaks', 's2_peaks' (sample indices) and
            'heart_rate'.
        sample_rate: Hz, used to convert sample indices to seconds.
        filename: output image path.

    Returns:
        True if the image was written, False otherwise (the error is printed).
    """
    try:
        # Use the object-oriented Figure API instead of pyplot. pyplot keeps a
        # global "current figure", which is unsafe when FastAPI runs the sync
        # /hb endpoint concurrently in its threadpool. A Figure that is never
        # registered with pyplot is garbage-collected normally, so nothing
        # leaks if plotting fails midway.
        from matplotlib.figure import Figure

        time = np.arange(len(audio_data)) / sample_rate
        filtered_signal = analysis_results.get('filtered_signal', audio_data)
        filtered_time = np.arange(len(filtered_signal)) / sample_rate

        fig = Figure(figsize=(15, 10))
        axes = fig.subplots(3, 1)

        # Original signal
        axes[0].plot(time, audio_data)
        axes[0].set_title('Original Heart Sound Signal')
        axes[0].set_xlabel('Time (s)')
        axes[0].set_ylabel('Amplitude')
        axes[0].grid(True)

        # Filtered signal with peaks
        axes[1].plot(filtered_time, filtered_signal)
        any_peaks = False
        for key, color, label in (('s1_peaks', 'red', 'S1 peaks'), ('s2_peaks', 'blue', 'S2 peaks')):
            peaks = np.asarray(analysis_results.get(key, []))
            if len(peaks) > 0:
                idx = peaks.astype(int)
                axes[1].scatter(peaks / sample_rate, filtered_signal[idx],
                                color=color, s=50, label=label, zorder=5)
                any_peaks = True
        axes[1].set_title('Filtered Heart Sound Signal with S1/S2 Detection')
        axes[1].set_xlabel('Time (s)')
        axes[1].set_ylabel('Amplitude')
        if any_peaks:
            # legend() without labelled artists only emits a warning.
            axes[1].legend()
        axes[1].grid(True)

        # Heart rate, spread evenly over the clip duration
        heart_rate = analysis_results.get('heart_rate', [])
        if len(heart_rate) > 0:
            hr_time = np.linspace(0, len(audio_data) / sample_rate, len(heart_rate))
            axes[2].plot(hr_time, heart_rate)
            axes[2].set_title('Heart Rate Over Time')
            axes[2].set_xlabel('Time (s)')
            axes[2].set_ylabel('Heart Rate (BPM)')
            axes[2].grid(True)
        else:
            axes[2].text(0.5, 0.5, 'No heart rate data available',
                         ha='center', va='center', transform=axes[2].transAxes)
            axes[2].set_title('Heart Rate Over Time')

        fig.tight_layout()
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        return True
    except Exception as e:
        print(f"Error creating plot: {e}")
        return False
# ============================================================================
# PYTORCH MODEL
# ============================================================================
class HeartSoundCNN(nn.Module):
    """1-D CNN classifier over a single-channel waveform.

    Four strided convolution stages (Conv1d -> BatchNorm1d -> ReLU -> pooling
    -> Dropout) are followed by a three-layer MLP head that produces
    `num_classes` logits. Attribute names (conv1..conv4, fc) and the layer
    order inside each nn.Sequential determine the state_dict keys
    (e.g. 'conv1.0.weight', 'fc.6.bias'), so they must stay stable for
    checkpoint loading.

    Args:
        num_classes: number of output logits.
        input_length: accepted for checkpoint compatibility but unused; the
            final AdaptiveAvgPool1d(1) makes the network length-agnostic.
    """

    def __init__(self, num_classes=4, input_length=60000):
        super().__init__()

        def stage(in_ch, out_ch, kernel, padding, pool, dropout):
            # Stride is 2 in every stage.
            return nn.Sequential(
                nn.Conv1d(in_ch, out_ch, kernel_size=kernel, stride=2, padding=padding),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
                pool,
                nn.Dropout(dropout),
            )

        self.conv1 = stage(1, 32, 50, 25, nn.MaxPool1d(4), 0.2)
        self.conv2 = stage(32, 64, 25, 12, nn.MaxPool1d(4), 0.2)
        self.conv3 = stage(64, 128, 10, 5, nn.MaxPool1d(4), 0.3)
        self.conv4 = stage(128, 256, 5, 2, nn.AdaptiveAvgPool1d(1), 0.3)
        self.fc = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Map a (batch, 1, samples) tensor to (batch, num_classes) logits."""
        for conv_stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv_stage(x)
        return self.fc(x.flatten(start_dim=1))
class HeartSoundPredictor:
    """Wrap a trained HeartSoundCNN checkpoint for single-clip inference.

    The model is loaded eagerly in __init__, onto CUDA when available and
    otherwise onto the CPU, and is left in eval mode.

    Args:
        model_path: path to a torch-saved dict containing 'model_state_dict'
            and optionally 'num_classes' (default 4) and 'input_shape'.

    Raises:
        FileNotFoundError: model_path does not exist.
        ValueError: the checkpoint's class count differs from CLASS_NAMES.
    """

    def __init__(self, model_path=MODEL_PATH):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.class_names = CLASS_NAMES
        self.model = None
        self._load_model()

    def _load_model(self):
        """Build HeartSoundCNN from the checkpoint, move it to self.device, set eval mode."""
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self.model_path, map_location=self.device)
        num_classes = checkpoint.get('num_classes', 4)
        # predict() maps output indices onto class_names. A mismatch would
        # silently mislabel (zip truncation) or raise IndexError per request,
        # so fail fast at load time instead.
        if num_classes != len(self.class_names):
            raise ValueError(
                f"Checkpoint has {num_classes} classes but CLASS_NAMES lists "
                f"{len(self.class_names)}: {self.class_names}"
            )
        input_shape = checkpoint.get('input_shape', (1, 60000))
        model = HeartSoundCNN(num_classes=num_classes, input_length=input_shape[1])
        model.load_state_dict(checkpoint['model_state_dict'])
        self.model = model.to(self.device)
        self.model.eval()

    def predict(self, audio_data, sample_rate):
        """Classify one clip.

        The clip is preprocessed with enhanced_preprocess_signal, zero-padded
        or truncated to SAMPLE_RATE * AUDIO_LENGTH samples, and fed to the
        model as a (1, 1, N) float32 tensor.

        Returns:
            dict with 'predicted_class' (label of the highest softmax
            probability), 'confidence' (that probability), 'probabilities'
            (label -> probability) and 'method'.
        """
        processed_audio = enhanced_preprocess_signal(audio_data, sample_rate)
        target_length = SAMPLE_RATE * AUDIO_LENGTH
        if len(processed_audio) < target_length:
            processed_audio = np.pad(processed_audio, (0, target_length - len(processed_audio)), mode='constant')
        else:
            processed_audio = processed_audio[:target_length]
        audio_tensor = torch.as_tensor(processed_audio, dtype=torch.float32).reshape(1, 1, -1).to(self.device)
        with torch.no_grad():
            probabilities = torch.softmax(self.model(audio_tensor), dim=1)
        confidence, predicted_class = torch.max(probabilities, dim=1)
        # Index the single batch row: squeeze() would collapse a 1-class output to a scalar.
        all_probs = probabilities[0].cpu().numpy()
        return {
            'predicted_class': self.class_names[predicted_class.item()],
            'confidence': float(confidence.item()),
            'probabilities': {class_name: float(prob) for class_name, prob in zip(self.class_names, all_probs)},
            'method': 'PyTorch CNN'
        }
# ============================================================================
# GLOBAL MODEL
# ============================================================================
# Load the model once at import time. On any failure PREDICTOR stays None and
# the endpoints report the model as unavailable (503 from /hb).
PREDICTOR = None
try:
    PREDICTOR = HeartSoundPredictor(MODEL_PATH)
except Exception as e:
    print(f"✗ Failed to load model: {e}")
else:
    print(f"✓ Model loaded: {MODEL_PATH}")
    print(f"✓ Device: {PREDICTOR.device}")
# ============================================================================
# FASTAPI ENDPOINTS
# ============================================================================
@app.get("/")
def index():
return {
"service": "Heart Sound Analysis API - Comprehensive",
"version": "2.0.0",
"model_loaded": PREDICTOR is not None,
"biosppy_available": BIOSPPY_AVAILABLE,
"usage": "GET /hb?filename=sounds/Heart_Failure_Sound.mp3"
}
@app.get("/health")
def health_check():
return {
"status": "OK",
"model_loaded": PREDICTOR is not None,
"biosppy_available": BIOSPPY_AVAILABLE,
"timestamp": datetime.now().isoformat()
}
@app.get("/hb")
def analyze_heart_sound_endpoint(filename: str):
"""Comprehensive heart sound analysis matching /analysis format"""
try:
if PREDICTOR is None:
raise HTTPException(status_code=503, detail="AI model not loaded")
if not os.path.exists(filename):
raise HTTPException(status_code=404, detail=f"File not found: {filename}")
# Load audio
audio_data, sample_rate = load_audio_file(filename)
if audio_data is None:
raise HTTPException(status_code=500, detail="Failed to load audio")
# AI Prediction
ai_result = PREDICTOR.predict(audio_data, sample_rate)
# Heart sound analysis
analysis_results = analyze_heart_sound(audio_data, sample_rate)
# Generate plot filename
target_name = os.path.splitext(os.path.basename(filename))[0]
plot_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
plot_filename = f"{target_name}_analysis_{plot_timestamp}.png"
# Create output directory if it doesn't exist
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)
# Save plot
plot_path = os.path.join(output_dir, plot_filename)
save_analysis_plot(audio_data, analysis_results, sample_rate, plot_path)
# Rule-based detection
arrhythmias = detect_arrhythmias(analysis_results)
# Combine findings
detailed_findings = []
# Add AI finding
detailed_findings.append({
"condition": f"AI Classification: {ai_result['predicted_class'].capitalize()}",
"description": f"Deep learning model prediction with {ai_result['confidence']:.1%} confidence",
"severity": "Moderate" if ai_result['confidence'] > 0.7 else "Low",
"recommendation": f"AI Model: {ai_result['predicted_class'].capitalize()} detected",
"source": "AI Model",
"probabilities": ai_result['probabilities']
})
# Add rule-based findings
detailed_findings.extend(arrhythmias)
# Heart attack risk assessment
heart_attack_risk = assess_heart_attack_risk(analysis_results, detailed_findings)
if heart_attack_risk['heart_attack_risk_flag']:
detailed_findings.append({
"condition": f"Heart Attack Risk Assessment - {heart_attack_risk['risk_level']} RISK",
"description": heart_attack_risk['risk_description'],
"severity": "High" if heart_attack_risk['risk_level'] == "HIGH" else "Moderate",
"recommendation": heart_attack_risk['risk_recommendation'],
"risk_factors": heart_attack_risk['risk_factors'],
"risk_score": heart_attack_risk['risk_score'],
"heart_attack_risk": True,
"source": "Rule-Based Analysis"
})
# Determine overall classification
ai_class_map = {
'normal': 'Normal Heart Sounds',
'murmur': 'Heart Murmur Detected',
'extrahs': 'Extra Heart Sounds Detected',
'artifact': 'Poor Signal Quality'
}
ai_classification = ai_class_map.get(ai_result['predicted_class'], 'Unknown')
if heart_attack_risk['risk_level'] in ['HIGH', 'MODERATE']:
overall_classification = f"ELEVATED CARDIAC RISK - {heart_attack_risk['risk_level']} Heart Attack Risk (AI: {ai_classification})"
overall_confidence = min(0.9, ai_result['confidence'] * 1.05)
else:
overall_classification = ai_classification
overall_confidence = ai_result['confidence']
# Agreement analysis
rule_suggests_abnormal = len([f for f in arrhythmias if f['severity'] in ['High', 'Moderate']]) > 0
ai_suggests_abnormal = ai_result['predicted_class'] != 'normal'
agreement = "High" if rule_suggests_abnormal == ai_suggests_abnormal else "Low"
# Recommendations
recommendations = [
"Analysis Method: Hybrid (AI + Rule-Based)",
f"AI-Rule Agreement: {agreement}"
]
for finding in detailed_findings:
if finding.get('recommendation'):
recommendations.append(finding['recommendation'])
# HRV metrics
heart_rate_data = analysis_results.get('heart_rate', [])
hrv_metrics = {}
if len(heart_rate_data) > 0:
hrv_metrics = {
'hr_std': float(np.std(heart_rate_data)),
'hr_mean': float(np.mean(heart_rate_data)),
'hr_min': float(np.min(heart_rate_data)),
'hr_max': float(np.max(heart_rate_data)),
'irregularity_score': float(np.std(heart_rate_data) / np.mean(heart_rate_data)) if np.mean(heart_rate_data) > 0 else 0
}
# Build response matching /analysis format
response_data = {
'status': 'OK',
'analysis_info': {
'target_file': target_name,
'original_filename': os.path.basename(filename),
'analysis_timestamp': datetime.now().isoformat(),
'analysis_mode': 'Hybrid (AI + Rule-Based)',
'ai_model_used': True,
'hybrid_mode': True,
'plot_filename': plot_filename,
'plot_path': f"/plot/{plot_filename}"
},
'signal_analysis': {
'method': analysis_results.get('method', 'Unknown'),
'sample_rate': analysis_results['sample_rate'],
'signal_duration': len(audio_data) / sample_rate,
'signal_length': len(audio_data)
},
'ai_prediction': {
'predicted_class': ai_result['predicted_class'],
'confidence': round(ai_result['confidence'], 4),
'probabilities': {k: round(v, 4) for k, v in ai_result['probabilities'].items()},
'method': ai_result['method']
},
'classification': {
'result': overall_classification,
'confidence': round(overall_confidence, 4),
'confidence_percentage': f"{overall_confidence:.1%}"
},
'heart_sounds': {
'num_s1': analysis_results['num_s1'],
'num_s2': analysis_results['num_s2'],
's1_peaks': analysis_results['s1_peaks'].tolist() if hasattr(analysis_results['s1_peaks'], 'tolist') else list(analysis_results['s1_peaks']),
's2_peaks': analysis_results['s2_peaks'].tolist() if hasattr(analysis_results['s2_peaks'], 'tolist') else list(analysis_results['s2_peaks']),
's1_s2_ratio': analysis_results['num_s1'] / analysis_results['num_s2'] if analysis_results['num_s2'] > 0 else 0
},
'heart_rate': {
'average_bpm': round(analysis_results['avg_heart_rate'], 2),
'heart_rate_data': heart_rate_data.tolist() if hasattr(heart_rate_data, 'tolist') else list(heart_rate_data),
'num_heart_rate_points': len(heart_rate_data),
**hrv_metrics
},
'detailed_findings': detailed_findings,
'recommendations': recommendations,
'file_context': {
'is_heart_failure_sample': 'heart_failure' in filename.lower() or 'Heart_Failure' in filename,
'is_normal_sample': 'normal' in filename.lower() or 'Normal' in filename,
'expected_findings': []
},
'disclaimers': [
'This is an automated analysis tool for educational/research purposes',
'Results should NOT be used for medical diagnosis or treatment decisions',
'Always consult qualified healthcare professionals for medical concerns',
'The analysis may have false positives/negatives - professional evaluation required'
]
}
return JSONResponse(content=response_data)
except HTTPException:
raise
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
@app.get("/plot/{filename}")
def get_plot(filename: str):
"""Serve plot images"""
try:
from fastapi.responses import FileResponse
plot_path = os.path.join("output", filename)
if os.path.exists(plot_path):
return FileResponse(plot_path, media_type="image/png")
else:
raise HTTPException(status_code=404, detail="Plot file not found")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error serving plot: {str(e)}")
@app.on_event("startup")
def startup_event():
print("\n" + "="*70)
print("Heart Sound Analysis API - Comprehensive Standalone")
print("="*70)
print(f"Model: {MODEL_PATH}")
print(f"Model Loaded: {'✓ Yes' if PREDICTOR else '✗ No'}")
print(f"BioSPPy: {'✓ Available' if BIOSPPY_AVAILABLE else '✗ Using Fallback'}")
print(f"Device: {PREDICTOR.device if PREDICTOR else 'N/A'}")
print(f"\nExample: curl 'http://localhost:8000/hb?filename=sounds/Heart_Failure_Sound.mp3'")
print("="*70 + "\n")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)