|
|
from typing import Dict, List, Any |
|
|
from NatureLM.models import NatureLM |
|
|
from NatureLM.infer import Pipeline |
|
|
import numpy as np |
|
|
import torch |
|
|
import os |
|
|
|
|
|
class EndpointHandler:
    """
    Request handler that answers text queries about audio with NatureLM-audio.

    The constructor loads the pretrained model and builds the inference
    pipeline once; every call then runs one audio/query pair through that
    pipeline and reports failures as error dictionaries instead of raising.
    """

    def __init__(self, path=""):
        """
        Load the model and build the inference pipeline.

        Parameters
        ----------
        path : str, optional
            Accepted for interface compatibility; this handler does not use it.

        Raises
        ------
        FileNotFoundError
            If ``inference.yml`` exists neither in the current working
            directory nor in the directory containing this file.
        """
        # Token is read from the environment and forwarded to the model loader.
        auth_token = os.environ.get("LLAMA_TOK")
        device = "cuda" if torch.cuda.is_available() else "cpu"

        model = NatureLM.from_pretrained(
            "EarthSpeciesProject/NatureLM-audio",
            device=device,
            hf_auth_token=auth_token,
        )
        self.model = model.eval().to(device)

        # Look for the config relative to the working directory first, then
        # fall back to the directory this file lives in.
        cfg_path = "inference.yml"
        if not os.path.exists(cfg_path):
            script_dir = os.path.dirname(os.path.abspath(__file__))
            cfg_path = os.path.join(script_dir, "inference.yml")

        if not os.path.exists(cfg_path):
            raise FileNotFoundError(f"inference.yml not found at {cfg_path}. Current directory contents: {os.listdir('.')}")

        self.pipeline = Pipeline(model=self.model, cfg_path=cfg_path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process audio list of floats with NatureLM-audio model.

        Parameters
        ----------
        data : Dict[str, Any]
            Dictionary containing:
            - inputs : list[float] or numpy.ndarray
                Audio data; a list is converted to a float32 array, an
                array is forwarded unchanged
            - query : str
                Question to ask about the audio
            - sample_rate : int, optional
                Audio sample rate, default 16000

        Returns
        -------
        List[Dict[str, Any]]
            Single-element list containing either a result dictionary with
            'result' and 'query' keys, or an error dictionary with an
            'error' key. Invalid input and pipeline failures are both
            reported through the error dictionary; nothing is raised for
            them.
        """
        audio = data.get("inputs")
        query = data.get("query", "")
        sample_rate = data.get("sample_rate", 16000)

        if audio is None:
            return [{"error": "No audio data provided"}]

        if not query:
            return [{"error": "No query provided"}]

        if isinstance(audio, list):
            # Non-numeric or ragged lists make the conversion raise; report
            # that through the error dictionary like every other failure.
            try:
                audio = np.array(audio, dtype=np.float32)
            except (ValueError, TypeError) as e:
                return [{"error": f"Audio data could not be converted to float32: {str(e)}"}]
        elif not isinstance(audio, np.ndarray):
            return [{"error": f"Audio data must be a list or numpy array, got {type(audio)}"}]

        # An empty signal carries nothing to analyse; reject it up front
        # rather than handing it to the model.
        if audio.size == 0:
            return [{"error": "Audio data is empty"}]

        try:
            results = self.pipeline(
                audios=[audio],
                queries=query,
                input_sample_rate=sample_rate,
            )
        except Exception as e:
            # Top-level service boundary: any pipeline failure becomes an
            # error response instead of an unhandled exception.
            return [{"error": f"Error processing audio: {str(e)}"}]

        return [{"result": results, "query": query}]