cough_ai / utils /inference.py
QnxprU69yCNg8XJ
we can start now
44626f7
raw
history blame contribute delete
890 Bytes
import torch
from transformers import AutoFeatureExtractor, AutoModel
# Load HeAR from Hugging Face.
# NOTE(review): these statements sit at module top level, so the feature
# extractor and the model are loaded as soon as this module is imported —
# confirm that import-time loading is intended.
MODEL_NAME = "google/hear"
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
hear_model = AutoModel.from_pretrained(MODEL_NAME)
# Put the model in evaluation mode; this module only runs it for inference.
hear_model.eval()
def get_embeddings(waveform, sr):
    """
    Turn an audio waveform into HeAR embeddings.

    waveform: audio samples, handed unchanged to the module-level
        feature extractor.
    sr: sampling rate of ``waveform``, forwarded as ``sampling_rate``.

    Returns the ``last_hidden_state`` of the HeAR model output
    (described by the original author as (batch, sequence, features)).
    """
    # Convert the waveform into PyTorch-tensor inputs for HeAR.
    model_inputs = feature_extractor(
        waveform,
        sampling_rate=sr,
        return_tensors="pt",
    )

    # Inference only, so gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        outputs = hear_model(**model_inputs)

    return outputs.last_hidden_state
def predict_risk(embeddings: torch.Tensor) -> float:
    """Return a placeholder risk score computed from embeddings.

    For now this is a dummy score: the L2 norm of each embedding vector
    (taken over the last dimension), averaged over every remaining
    dimension. It is meant to be replaced later by the pneumonia classifier.

    Args:
        embeddings: Floating-point tensor whose last dimension holds the
            feature values.

    Returns:
        The mean L2 norm as a plain Python float.

    Raises:
        ValueError: If ``embeddings`` contains no elements; there is nothing
            to score, and the result would otherwise be meaningless
            (NaN when there are no vectors to average).
    """
    if embeddings.numel() == 0:
        raise ValueError("embeddings must contain at least one element")

    # torch.norm is deprecated; torch.linalg.vector_norm is the maintained
    # equivalent for a vector 2-norm along a single dimension.
    per_vector_norm = torch.linalg.vector_norm(embeddings, ord=2, dim=-1)
    return per_vector_norm.mean().item()