Safetensors
EarthSpeciesProject
NatureLM
NatureLM-audio / handler.py
Cheeky Sparrow
update handler
8172885
from typing import Dict, List, Any
from NatureLM.models import NatureLM
from NatureLM.infer import Pipeline
import numpy as np
import torch
import os
class EndpointHandler():
    """Custom inference handler that answers text queries about audio clips.

    On construction the handler loads the ``EarthSpeciesProject/NatureLM-audio``
    checkpoint and wraps it in a ``Pipeline`` configured from ``inference.yml``.
    Calling the instance runs one query against one audio clip.
    """

    def __init__(self, path=""):
        """
        Load the model and build the inference pipeline.

        Parameters
        ----------
        path : str, optional
            Accepted for interface compatibility; not used by this handler.

        Raises
        ------
        FileNotFoundError
            If ``inference.yml`` is found neither in the current working
            directory nor next to this file.
        """
        # The auth token is read from the LLAMA_TOK environment variable and
        # forwarded to from_pretrained; it is None when the variable is unset.
        auth_token = os.environ.get("LLAMA_TOK")
        device = "cuda" if torch.cuda.is_available() else "cpu"

        model = NatureLM.from_pretrained(
            "EarthSpeciesProject/NatureLM-audio",
            device=device,
            hf_auth_token=auth_token,
        )
        self.model = model.eval().to(device)

        # Look for the config in the working directory first, then fall back
        # to the directory that contains this file.
        cfg_path = "inference.yml"
        if not os.path.exists(cfg_path):
            script_dir = os.path.dirname(os.path.abspath(__file__))
            cfg_path = os.path.join(script_dir, "inference.yml")
            if not os.path.exists(cfg_path):
                raise FileNotFoundError(f"inference.yml not found at {cfg_path}. Current directory contents: {os.listdir('.')}")

        self.pipeline = Pipeline(model=self.model, cfg_path=cfg_path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process audio list of floats with NatureLM-audio model.

        Invalid input and inference failures are reported as an error
        dictionary rather than raised, so the caller always receives a
        one-element list.

        Parameters
        ----------
        data : Dict[str, Any]
            Dictionary containing:
            - inputs : list[float] or numpy.ndarray
                Audio data as list of floats
            - query : str
                Question to ask about the audio
            - sample_rate : int, optional
                Audio sample rate, default 16000 (also used when the key is
                present but null)

        Returns
        -------
        List[Dict[str, Any]]
            List containing result dictionary with 'result' and 'query' keys,
            or error dictionary with 'error' key
        """
        audio = data.get("inputs")
        query = data.get("query", "")

        # A JSON payload may carry an explicit null; treat it like a missing key.
        sample_rate = data.get("sample_rate")
        if sample_rate is None:
            sample_rate = 16000

        if audio is None:
            return [{"error": "No audio data provided"}]
        if not query:
            return [{"error": "No query provided"}]

        # Convert list to numpy array if needed (when sent via JSON)
        if isinstance(audio, list):
            try:
                audio = np.array(audio, dtype=np.float32)
            except (ValueError, TypeError) as e:
                # Ragged or non-numeric lists cannot become a float array;
                # report that instead of letting the exception escape.
                return [{"error": f"Audio data could not be converted to a float array: {str(e)}"}]
        elif not isinstance(audio, np.ndarray):
            return [{"error": f"Audio data must be a list or numpy array, got {type(audio)}"}]

        if audio.size == 0:
            return [{"error": "Audio data is empty"}]

        try:
            # Run inference using the pipeline
            results = self.pipeline(
                audios=[audio],
                queries=query,
                input_sample_rate=sample_rate,
            )
        except Exception as e:
            # Endpoint boundary: surface any inference failure to the client.
            return [{"error": f"Error processing audio: {str(e)}"}]

        return [{"result": results, "query": query}]