| from typing import Dict, List, Any |
| from NatureLM.models import NatureLM |
| from NatureLM.infer import Pipeline |
| import numpy as np |
| import torch |
| import os |
|
|
class EndpointHandler:
    """Callable inference handler around the NatureLM-audio model.

    Construction loads the pretrained model and wraps it in a ``Pipeline``;
    calling the instance runs one audio clip and one query through that
    pipeline and returns the outcome as a single-element list.
    """

    def __init__(self, path=""):
        """Load the model and build the inference pipeline.

        Parameters
        ----------
        path : str, optional
            Not read by this handler; kept so the constructor signature
            stays the same for existing callers.

        Raises
        ------
        FileNotFoundError
            If ``inference.yml`` exists neither in the current working
            directory nor in the directory containing this file.
        """
        auth_token = os.environ.get("LLAMA_TOK")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = NatureLM.from_pretrained(
            "EarthSpeciesProject/NatureLM-audio",
            device=device,
            hf_auth_token=auth_token,
        )
        self.model = model.eval().to(device)

        # Look in the working directory first, then next to this file.
        cfg_path = "inference.yml"
        if not os.path.exists(cfg_path):
            script_dir = os.path.dirname(os.path.abspath(__file__))
            cfg_path = os.path.join(script_dir, "inference.yml")

        if not os.path.exists(cfg_path):
            raise FileNotFoundError(f"inference.yml not found at {cfg_path}. Current directory contents: {os.listdir('.')}")

        self.pipeline = Pipeline(model=self.model, cfg_path=cfg_path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process audio list of floats with NatureLM-audio model.

        Parameters
        ----------
        data : Dict[str, Any]
            Dictionary containing:
            - inputs : list[float]
                Audio data as list of floats (a numpy array is also
                accepted and is forwarded unchanged)
            - query : str
                Question to ask about the audio
            - sample_rate : int, optional
                Audio sample rate, default 16000 (also used when the key
                is present but set to None)

        Returns
        -------
        List[Dict[str, Any]]
            List containing result dictionary with 'result' and 'query' keys,
            or error dictionary with 'error' key. Invalid input and pipeline
            failures are reported through the error dictionary rather than
            raised.
        """
        audio = data.get("inputs")
        query = data.get("query", "")
        sample_rate = data.get("sample_rate")
        if sample_rate is None:
            # An explicit null in the payload gets the documented default.
            sample_rate = 16000

        if audio is None:
            return [{"error": "No audio data provided"}]

        if not query:
            return [{"error": "No query provided"}]

        if isinstance(audio, list):
            # Conversion fails on non-numeric or ragged lists; report that as
            # an error payload instead of letting the exception escape.
            try:
                audio = np.array(audio, dtype=np.float32)
            except (ValueError, TypeError) as exc:
                return [{"error": f"Invalid audio data: {exc}"}]
        elif not isinstance(audio, np.ndarray):
            return [{"error": f"Audio data must be a list or numpy array, got {type(audio)}"}]

        if audio.size == 0:
            return [{"error": "Audio data is empty"}]

        # Boundary catch-all: any pipeline failure becomes an error payload
        # so the caller always receives the documented response shape.
        try:
            results = self.pipeline(
                audios=[audio],
                queries=query,
                input_sample_rate=sample_rate,
            )
        except Exception as e:
            return [{"error": f"Error processing audio: {str(e)}"}]

        return [{"result": results, "query": query}]