babyapi / app.py
SwitchAlpha's picture
Update app.py
4d63b6f verified
raw
history blame
17.6 kB
import os
import torch
import torch.nn as nn
import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import io
import torchaudio
import librosa
import random
import json
import logging
from contextlib import asynccontextmanager
# [YENİ EKLEME] Hugging Face Transformers kütüphanelerini içe aktarın
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
# --- 1. Günlük Yapılandırması ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- 2. Yapılandırma Seçenekleri ---
class Config:
CLASS_NAMES = ['belly_pain', 'burping', 'discomfort', 'hungry']
NUM_CLASSES = len(CLASS_NAMES)
MODEL_PATH_ON_HF_SPACE = "cry_model_quantized_final.ptl"
TARGET_SAMPLE_RATE = 22050
AUDIO_DURATION_SECONDS = 6.0
TARGET_AUDIO_LENGTH = int(AUDIO_DURATION_SECONDS * TARGET_SAMPLE_RATE)
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 128
N_MFCC = 120
N_CHROMA = 12
MODEL_ARCHITECTURE = 'vanilla'
MODEL_BASE_CHANNELS = 32
MODEL_DROPOUT_RATE = 0.4
FUSION_DIM = 256
DEVICE = torch.device('cpu')
THRESHOLD_HIGH_CONFIDENCE = 0.49
THRESHOLD_TWO_CATEGORIES = 0.40
NUM_CARE_POINTS = 2
FAILURE_REDIRECT_STATUS_TEXT = "redirect_to_failure"
# [YENİ EKLEME] ESC-50 "Kapı Bekçisi" Modeli için yapılandırma
ESC50_MODEL_NAME = "bioamla/ast-esc50"
ESC50_TARGET_SR = 16000 # Bu model 16kHz örnekleme hızı bekler
CRY_CONFIDENCE_THRESHOLD = 0.5 # Ağlama olarak kabul etmek için minimum güven eşiği
cfg = Config()
# --- 3. Bakım Noktaları "Veritabanı" ---
CARE_POINTS_DB_RAW = {
"belly_pain": """Bebek olabilir: Karnı rahatsız. [Ana Neden] Hava yutma. Beslendikten hemen sonra geğirme, muhtemelen şişkinlikle birlikte; [Bakım Noktası] Besledikten sonra bebeği dik tutun ve 10-15 dakika dik pozisyonda tutarak gazını çıkarın.---Bebek olabilir: Karnı rahatsız. [Ana Neden] Olgunlaşmamış sindirim sistemi. Yenidoğanlarda sık geğirme, özellikle beslendikten sonra; [Bakım Noktası] Düzenli beslenmeyi sürdürün, aşırı beslemekten kaçının; gaz çıkarmasına yardımcı olmak için beslendikten sonra sırtına hafifçe vurun.---Bebek olabilir: Karnı rahatsız. [Ana Neden] Çok hızlı veya çok fazla yemek. Geğirme, huysuzluk, kusma eşliğinde; [Bakım Noktası] Besleme hızını kontrol edin, daha küçük, daha sık öğünler benimseyin; besleme sırasında uygun şekilde duraklayın.---Bebek olabilir: Karnı rahatsız. [Ana Neden] Duygusal değişiklikler. Ağladıktan veya güldükten sonra geğirme; [Bakım Noktası] Bebeğin duygularını yatıştırın, yoğun ağlama veya heyecandan kaçının; beslenmeden önce ve sonra sakin bir ortam sağlayın.---Bebek olabilir: Karnı rahatsız. [AnaNeden] Sıcaklık değişiklikleri. Ani ortam sıcaklığı değişikliklerinden veya bebek üşüdüğünde geğirme; [Bakım Noktası] Oda sıcaklığını sabit tutun, bebeğin karnının üşümesini önleyin; sıcak tutmaya dikkat edin.""",
"burping": """Besledikten sonra daima bebeğin gazını çıkarın.---Geğirme pozisyonlarını değiştirin, birden fazla yöntem deneyin.---Çok fazla hava yutmasını önlemek için biberon emziği akış hızının uygun olup olmadığını kontrol edin.---Bir seferde yutulan hava miktarını azaltmak için daha küçük, daha sık öğünler sunun.---Eğer bebek sık sık geğiriyorsa, beslenme alışkanlıklarının ayarlanması gerekebilir.""",
"discomfort": """Bezin ıslak veya kirli olup olmadığını kontrol edin.---Oda sıcaklığının rahat olup olmadığını, bebeğin çok sıcak veya çok soğuk giydirilip giydirilmediğini onaylayın.---Bebeğin kıyafetlerinin çok sıkı olup olmadığını veya rahatsızlık verip vermediğini kontrol edin.---Bebeğin böcekler tarafından ısırılmadığından veya cildinin kaşınmadığından emin olun.---Rahatsızlığı hafifletip hafifletmediğini görmek için bebeğin pozisyonunu değiştirmeyi deneyin.""",
"hungry": """Bebek acıkmış olabilir, lütfen zamanında besleyin.---Son beslenme zamanını kontrol edin, yemek zamanı mı?---Biberon veya emzirmeyi deneyin, bebeğin aranma refleksi gösterip göstermediğine bakın.---Bebeğin doyduğundan emin olun, beslendikten sonra bebeğin tepkisini gözlemleyin.---Eğer bebek kontrolsüzce ağlıyorsa, küçük bir ek beslenme gerekebilir."""
}
# (Not: Veritabanı metinlerini daha iyi anlaşılması için İngilizce'den Türkçe'ye çevirdim)
# --- 4. Model Tanımı (Değişmedi) ---
# ... (ResidualBlock ve CryNetMultiBranch sınıflarınız burada, değişiklik yok) ...
# Önceki kodunuzdaki Model Tanımı bölümünü buraya kopyalayın
# (Yorum satırı: Sadelik için model tanımını buraya tekrar eklemedim,
# ama sizin kodunuzda tam olarak burada olmalı)
class ResidualBlock(nn.Module): # Kodu tam hale getirmek için ekliyorum
def __init__(self, in_c, out_c, stride, dr):
super().__init__()
self.conv1 = nn.Conv2d(in_c, out_c, 3, stride, 1, bias=False)
self.bn1 = nn.BatchNorm2d(out_c)
self.relu = nn.ReLU(True)
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=False)
self.bn2 = nn.BatchNorm2d(out_c)
self.downsample = nn.Sequential()
if stride != 1 or in_c != out_c:
self.downsample = nn.Sequential(
nn.Conv2d(in_c, out_c, 1, stride, bias=False),
nn.BatchNorm2d(out_c)
)
self.dropout = nn.Dropout(dr)
def forward(self, x):
res = self.downsample(x)
x = self.relu(self.bn1(self.conv1(x)))
x = self.bn2(self.conv2(x))
x = self.dropout(x)
return self.relu(x + res)
class CryNetMultiBranch(nn.Module):
def __init__(self, num_classes, config):
super().__init__()
self.config = config
self.mel_encoder = self._create_encoder(1)
self.mfcc_encoder = self._create_encoder(1)
self.chroma_encoder = self._create_encoder(1)
encoder_output_dim = config.MODEL_BASE_CHANNELS * 8 * 3
self.fusion_mlp = nn.Sequential(
nn.Linear(encoder_output_dim, config.FUSION_DIM),
nn.LayerNorm(config.FUSION_DIM),
nn.ReLU(True),
nn.Dropout(config.MODEL_DROPOUT_RATE + 0.1),
nn.Linear(config.FUSION_DIM, num_classes)
)
def _make_layer(self, block, in_c, out_c, blocks, stride, dr):
layers = [block(in_c, out_c, stride, dr)]
[layers.append(block(out_c, out_c, 1, dr)) for _ in range(1, blocks)]
return nn.Sequential(*layers)
def _create_encoder(self, in_channels):
base_channels, dr = self.config.MODEL_BASE_CHANNELS, self.config.MODEL_DROPOUT_RATE
blocks_layer4 = 3 if self.config.MODEL_ARCHITECTURE == 'deeper' else 2
return nn.Sequential(
nn.Conv2d(in_channels, base_channels, 7, 2, 3, bias=False), nn.BatchNorm2d(base_channels),
nn.ReLU(True), nn.MaxPool2d(3, 2, 1),
self._make_layer(ResidualBlock, base_channels, base_channels * 2, 2, 2, dr),
self._make_layer(ResidualBlock, base_channels * 2, base_channels * 4, 2, 2, dr),
self._make_layer(ResidualBlock, base_channels * 4, base_channels * 8, blocks_layer4, 2, dr),
nn.AdaptiveAvgPool2d((1, 1))
)
def forward(self, mel, mfcc, chroma):
features = [self.mel_encoder(mel), self.mfcc_encoder(mfcc), self.chroma_encoder(chroma)]
combined = torch.cat([torch.flatten(f, 1) for f in features], dim=1)
return self.fusion_mlp(combined)
# --- 5. Global Kaynak Başlatma ---
app_globals = {}
# --- 6. Kaynakları lifespan olay yöneticisi ile yükleme ---
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
logging.info("Uygulama başlıyor, kaynaklar yükleniyor...")
# Orijinal modelinizi yükleyin
app_globals["loaded_model"] = torch.jit.load(cfg.MODEL_PATH_ON_HF_SPACE, map_location=cfg.DEVICE)
app_globals["loaded_model"].eval()
logging.info(f"✅ Orijinal Model {cfg.MODEL_PATH_ON_HF_SPACE} başarıyla yüklendi!")
# [YENİ EKLEME] ESC-50 "Kapı Bekçisi" modelini ve çıkarıcıyı yükleyin
app_globals["esc50_model"] = AutoModelForAudioClassification.from_pretrained(cfg.ESC50_MODEL_NAME).to(cfg.DEVICE)
app_globals["esc50_model"].eval()
app_globals["esc50_extractor"] = AutoFeatureExtractor.from_pretrained(cfg.ESC50_MODEL_NAME)
logging.info(f"✅ ESC-50 'Kapı Bekçisi' Modeli {cfg.ESC50_MODEL_NAME} başarıyla yüklendi!")
# Orijinal dönüştürücüleriniz
app_globals["mel_spectrogram_transform"] = torchaudio.transforms.MelSpectrogram(
sample_rate=cfg.TARGET_SAMPLE_RATE, n_fft=cfg.N_FFT, hop_length=cfg.HOP_LENGTH, n_mels=cfg.N_MELS
).to(cfg.DEVICE)
app_globals["mfcc_transform"] = torchaudio.transforms.MFCC(
sample_rate=cfg.TARGET_SAMPLE_RATE, n_mfcc=cfg.N_MFCC, melkwargs={"n_fft": cfg.N_FFT, "hop_length": cfg.HOP_LENGTH}
).to(cfg.DEVICE)
logging.info("✅ Ses dönüştürücüleri başarıyla başlatıldı!")
app_globals["CARE_POINTS_DB"] = {}
for category, raw_text in CARE_POINTS_DB_RAW.items():
points = [point.strip() for point in raw_text.strip().split('---') if point.strip()]
app_globals["CARE_POINTS_DB"][category] = points
logging.info("✅ Bakım bilgi bankası başarıyla ayrıştırıldı!")
app_globals["resampler_cache"] = {}
# [YENİ EKLEME] ESC-50 modeli için ayrı bir yeniden örnekleyici önbelleği
app_globals["esc50_resampler_cache"] = {}
logging.info("✅ Yeniden örnekleyici önbellekleri başarıyla başlatıldı!")
except Exception as e:
logging.error(f"❌ Uygulama başlatılamadı: Kaynaklar yüklenirken hata oluştu: {e}", exc_info=True)
yield
logging.info("Uygulama kapanıyor, kaynaklar temizleniyor...")
app_globals.clear()
app = FastAPI(lifespan=lifespan)
@app.get("/")
async def read_root():
return {"message": "Bebek Ağlaması Sınıflandırma API'si çalışıyor! /predict_cry_audio adresine ses dosyası POST edin"}
# [YENİ EKLEME] Yeniden örnekleyicileri almak için yardımcı fonksiyon
def get_resampler(cache, orig_freq, new_freq, device):
if orig_freq not in cache:
cache[orig_freq] = {}
if new_freq not in cache[orig_freq]:
cache[orig_freq][new_freq] = torchaudio.transforms.Resample(
orig_freq=orig_freq, new_freq=new_freq
).to(device)
return cache[orig_freq][new_freq]
@app.post("/predict_cry_audio")
async def predict_cry_audio(file: UploadFile = File(...)):
if not app_globals.get("loaded_model") or not app_globals.get("esc50_model"):
raise HTTPException(status_code=503, detail="Modeller yüklenmedi, hizmet geçici olarak kullanılamıyor.")
try:
audio_bytes = await file.read()
audio_buffer = io.BytesIO(audio_bytes)
waveform, sample_rate = torchaudio.load(audio_buffer)
waveform = waveform.to(cfg.DEVICE)
# [YENİ EKLEME] Sesi monoya dönüştür (her iki model için de en iyisi)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
# --- [YENİ ADIM 1: "KAPI BEKÇİSİ" KONTROLÜ] ---
# Sesi ESC-50 modelinin beklediği 16kHz'e yeniden örnekle
esc_resampler = get_resampler(app_globals["esc50_resampler_cache"], sample_rate, cfg.ESC50_TARGET_SR, cfg.DEVICE)
waveform_16k = esc_resampler(waveform)
# Özellikleri çıkar ve modeli çalıştır
inputs = app_globals["esc50_extractor"](
waveform_16k.cpu().numpy().squeeze(),
sampling_rate=cfg.ESC50_TARGET_SR,
return_tensors="pt"
).to(cfg.DEVICE)
with torch.no_grad():
logits = app_globals["esc50_model"](**inputs).logits
probs = logits.softmax(dim=-1)
top_prob, top_class_idx = torch.max(probs, dim=-1)
top_class_name = app_globals["esc50_model"].config.id2label[top_class_idx.item()]
top_prob_val = top_prob.item()
logging.info(f"ESC-50 'Kapı Bekçisi' sonucu: {top_class_name} (Güven: {top_prob_val:.4f})")
# --- [YENİ ADIM 2: KONTROL VE YÖNLENDİRME] ---
if top_class_name != "crying_baby" or top_prob_val < cfg.CRY_CONFIDENCE_THRESHOLD:
# Bu bir bebek ağlaması değil veya güven çok düşük
return JSONResponse(content={
"status": "no_cry_detected",
"message": "AI, yüklenen seste bir bebek ağlaması tespit etmedi.",
"detected_sound": top_class_name,
"confidence": top_prob_val
})
# --- [ORİJİNAL ADIM: AĞLAMA TİPİ SINIFLANDIRMASI] ---
# (Sadece 'crying_baby' tespit edildiyse bu bölüm çalışır)
logging.info("Bebek ağlaması doğrulandı. Ağlama tipi sınıflandırılıyor...")
# Orijinal modeliniz için 22050Hz'e yeniden örnekleyin
resampler_cache = app_globals["resampler_cache"]
if sample_rate != cfg.TARGET_SAMPLE_RATE:
# [GÜNCELLEME] get_resampler fonksiyonunu kullan
resampler = get_resampler(resampler_cache, sample_rate, cfg.TARGET_SAMPLE_RATE, cfg.DEVICE)
waveform_22k = resampler(waveform)
else:
waveform_22k = waveform # Zaten doğru örnekleme hızında
# Orijinal dolgu/kırpma işleminiz
if waveform_22k.shape[1] < cfg.TARGET_AUDIO_LENGTH:
waveform_22k = torch.nn.functional.pad(waveform_22k, (0, cfg.TARGET_AUDIO_LENGTH - waveform_22k.shape[1]))
else:
waveform_22k = waveform_22k[:, :cfg.TARGET_AUDIO_LENGTH]
# Orijinal özellik çıkarımlarınız (waveform_22k kullanarak)
mel_spec = app_globals["mel_spectrogram_transform"](waveform_22k).unsqueeze(0)
mfcc = app_globals["mfcc_transform"](waveform_22k).unsqueeze(0)
waveform_numpy = waveform_22k.cpu().numpy()[0]
chroma_numpy = librosa.feature.chroma_stft(
y=waveform_numpy, sr=cfg.TARGET_SAMPLE_RATE, n_fft=cfg.N_FFT,
hop_length=cfg.HOP_LENGTH, n_chroma=cfg.N_CHROMA
)
chroma = torch.from_numpy(chroma_numpy).to(cfg.DEVICE)
with torch.no_grad():
if chroma.dim() == 2:
chroma = chroma.unsqueeze(0).unsqueeze(0)
# Orijinal modelinizi çalıştırın
probabilities = app_globals["loaded_model"](mel_spec, mfcc, chroma)
probabilities_list = probabilities.squeeze().tolist()
confidences_data = [{"label": cfg.CLASS_NAMES[i], "score": round(score, 4)} for i, score in enumerate(probabilities_list)]
sorted_confidences = sorted(confidences_data, key=lambda x: x['score'], reverse=True)
top1_label = sorted_confidences[0]['label']
top1_score = sorted_confidences[0]['score']
except Exception as e:
logging.error(f"Ses dosyası işlenirken kritik hata: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Ses dosyası işlenemedi, lütfen tekrar deneyin veya dosya formatını kontrol edin.")
# Orijinal JSON yanıt mantığınız (hiç değişmedi)
response_data = {}
CARE_POINTS_DB = app_globals["CARE_POINTS_DB"]
if top1_score > cfg.THRESHOLD_HIGH_CONFIDENCE:
care_points = CARE_POINTS_DB.get(top1_label, [])
selected_points = random.sample(care_points, min(len(care_points), cfg.NUM_CARE_POINTS))
response_data = {
"status": "success_single",
"title": f"Bebek muhtemelen şundan ağlıyor: {top1_label}",
"subtitle": "Ana nedenler ve bakım noktaları:",
"care_points": selected_points,
"promo_text": "Daha detaylı rehberlik için üyelik satın alabilirsiniz",
"all_scores": sorted_confidences
}
elif len(sorted_confidences) > 1 and sorted_confidences[1]['score'] >= cfg.THRESHOLD_TWO_CATEGORIES:
top2_label = sorted_confidences[1]['label']
combined_points = []
care_points_1 = CARE_POINTS_DB.get(top1_label, [])
if care_points_1: combined_points.append(random.choice(care_points_1))
care_points_2 = CARE_POINTS_DB.get(top2_label, [])
if care_points_2: combined_points.append(random.choice(care_points_2))
response_data = {
"status": "success_multiple",
"title": f"Bebek muhtemelen şundan ağlıyor: {top1_label} veya {top2_label}",
"subtitle": "Ana nedenler ve bakım noktaları:",
"care_points": combined_points,
"promo_text": "Daha detaylı rehberlik için üyelik satın alabilirsiniz",
"all_scores": sorted_confidences
}
else:
response_data = {
"status": cfg.FAILURE_REDIRECT_STATUS_TEXT,
"message": "AI net bir neden belirleyemedi, sizin için genel yatıştırma yönergeleri hazırladık.",
"all_scores": sorted_confidences
}
return JSONResponse(content=response_data)
# --- 7. Uygulama Başlatıcı ---
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)