"""
Multi-modal model inference runner using Helium virtual GPU.
"""

import os
import shutil
import sys
import time
from typing import Dict, List, Optional, Union

import cv2
import numpy as np
import soundfile as sf

# Make the repository root and the virtual GPU driver directory importable.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, root_dir)
sys.path.insert(0, os.path.join(root_dir, 'virtual_gpu_driver'))

from inference.app import MultiModalModel


class InferenceRunner:
    """Handles model inference with efficient batching and caching"""

    def __init__(
        self,
        model_id: str = "openai-oss-20b",
        device_id: str = "vgpu0",
        batch_size: int = 1,
        cache_dir: Optional[str] = None
    ):
        self.model = MultiModalModel.from_pretrained(
            model_id=model_id,
            device_id=device_id,
            cache_dir=cache_dir
        )
        self.model.eval()

        self.batch_size = batch_size
        self.cache_dir = cache_dir
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        # Per-modality buffers, flushed once a buffer reaches batch_size.
        self._text_batch = []
        self._image_batch = []
        self._audio_batch = []

    def preprocess_text(self, text: Union[str, List[str]]) -> np.ndarray:
        """Preprocess text input"""
        if isinstance(text, str):
            text = [text]

        # Placeholder tokenization: emits an all-zeros (batch, 64) id
        # tensor; a real implementation would run the model's tokenizer.
        text_data = np.zeros((len(text), 64), dtype=np.int32)
        return text_data
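
    # Illustrative sketch only: one way the placeholder above could be
    # replaced. `_byte_tokenize` is a hypothetical helper (not part of the
    # MultiModalModel API) that packs UTF-8 byte values into the same
    # fixed (batch, 64) id window.
    def _byte_tokenize(self, texts: List[str], max_len: int = 64) -> np.ndarray:
        ids = np.zeros((len(texts), max_len), dtype=np.int32)
        for i, t in enumerate(texts):
            encoded = list(t.encode("utf-8"))[:max_len]  # truncate to window
            ids[i, :len(encoded)] = encoded
        return ids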

    def preprocess_image(self, image_path: str) -> np.ndarray:
        """Preprocess image input"""
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        image = cv2.resize(image, (224, 224))

        # OpenCV loads BGR; convert to RGB, then scale [0, 255] -> [-1, 1].
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype(np.float32) / 127.5 - 1.0

        # HWC -> CHW with a leading batch axis: (1, 3, 224, 224).
        image = image.transpose(2, 0, 1)[None]
        return image
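
    # Hedged sketch (hypothetical helper, not called by the runner):
    # inverse of the normalization above, handy when eyeballing a
    # preprocessed tensor.
    @staticmethod
    def _denormalize_image(image: np.ndarray) -> np.ndarray:
        # (1, 3, 224, 224) in [-1, 1] -> (224, 224, 3) uint8 RGB
        chw = (image[0] + 1.0) * 127.5
        return np.clip(chw, 0, 255).astype(np.uint8).transpose(1, 2, 0)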

    def preprocess_audio(self, audio_path: str) -> np.ndarray:
        """Preprocess audio input"""
        audio, sr = sf.read(audio_path)

        # Downmix multi-channel audio to mono.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        # Peak-normalize, guarding against silent (all-zero) clips.
        peak = np.abs(audio).max()
        if peak > 0:
            audio = audio / peak

        # Truncate or zero-pad to a fixed 16000-sample window.
        target_len = 16000
        if len(audio) > target_len:
            audio = audio[:target_len]
        else:
            audio = np.pad(audio, (0, target_len - len(audio)))

        # Shape (1, 1, 16000), float32.
        return audio.astype(np.float32)[None, None]
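
    # Hedged sketch: sf.read returns the file's native rate (sr above is
    # otherwise unused), while the fixed 16000-sample window suggests the
    # model expects 16 kHz. A minimal linear-interpolation resampler,
    # assuming that target rate (`_resample_to_16k` is hypothetical):
    @staticmethod
    def _resample_to_16k(audio: np.ndarray, sr: int) -> np.ndarray:
        if sr == 16000:
            return audio
        n_out = int(round(len(audio) * 16000 / sr))
        x_old = np.linspace(0.0, 1.0, num=len(audio), endpoint=False)
        x_new = np.linspace(0.0, 1.0, num=n_out, endpoint=False)
        return np.interp(x_new, x_old, audio)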

    def add_to_batch(
        self,
        text: Optional[Union[str, List[str]]] = None,
        image_path: Optional[str] = None,
        audio_path: Optional[str] = None
    ) -> None:
        """Add inputs to batch for processing"""
        if text is not None:
            text_data = self.preprocess_text(text)
            self._text_batch.append(text_data)

        if image_path is not None:
            image_data = self.preprocess_image(image_path)
            self._image_batch.append(image_data)

        if audio_path is not None:
            audio_data = self.preprocess_audio(audio_path)
            self._audio_batch.append(audio_data)

        # Flush when any modality buffer fills; checking only the text
        # buffer would never flush image- or audio-only workloads.
        fullest = max(
            len(self._text_batch),
            len(self._image_batch),
            len(self._audio_batch),
        )
        if fullest >= self.batch_size:
            self.process_batch()
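
    # Hedged usage sketch for the auto-flush above (hypothetical
    # batch_size=2 runner):
    #   runner.add_to_batch(text="first")   # buffered
    #   runner.add_to_batch(text="second")  # buffer hits 2 -> process_batch()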

    def process_batch(self) -> Dict[str, np.ndarray]:
        """Process current batch through model"""
        if not any([self._text_batch, self._image_batch, self._audio_batch]):
            return {}

        # Stack each modality's buffered arrays along the batch axis.
        inputs = {}
        if self._text_batch:
            inputs["text"] = np.concatenate(self._text_batch, axis=0)
        if self._image_batch:
            inputs["image"] = np.concatenate(self._image_batch, axis=0)
        if self._audio_batch:
            inputs["audio"] = np.concatenate(self._audio_batch, axis=0)

        outputs = self.model(inputs)

        # Optionally persist raw outputs for later inspection.
        if self.cache_dir:
            timestamp = str(int(time.time()))
            cache_path = os.path.join(self.cache_dir, f"batch_{timestamp}.npz")
            np.savez(cache_path, **outputs)

        self._text_batch.clear()
        self._image_batch.clear()
        self._audio_batch.clear()

        return outputs
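
    # Hedged sketch: cached batches can be reloaded with np.load; the
    # archive keys mirror whatever dict the model returns (file name and
    # keys below are illustrative):
    #   with np.load("inference_cache/batch_1700000000.npz") as data:
    #       print(list(data.keys()))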

    def generate_from_context(
        self,
        context_text: Optional[str] = None,
        context_image: Optional[str] = None,
        context_audio: Optional[str] = None,
        max_length: int = 100
    ) -> np.ndarray:
        """Generate sequence using multi-modal context"""
        inputs = {}

        if context_text:
            inputs["text"] = self.preprocess_text(context_text)
        if context_image:
            inputs["image"] = self.preprocess_image(context_image)
        if context_audio:
            inputs["audio"] = self.preprocess_audio(context_audio)

        return self.model.generate(inputs, max_length=max_length)

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        image_path: Optional[str] = None,
        audio_path: Optional[str] = None
    ) -> Dict[str, np.ndarray]:
        """Run inference on inputs, flushing the batch immediately"""
        # Note: if add_to_batch itself triggers an auto-flush, those
        # results are produced there and this process_batch call returns
        # {} for the now-empty buffers.
        self.add_to_batch(text, image_path, audio_path)
        return self.process_batch()

    def cleanup(self):
        """Release the model and remove the on-disk cache"""
        self.model.cleanup()
        if self.cache_dir and os.path.exists(self.cache_dir):
            shutil.rmtree(self.cache_dir)


def main():
    """Example usage"""
    runner = InferenceRunner(
        model_id="openai-oss-20b",
        device_id="vgpu0",
        batch_size=4,
        cache_dir="inference_cache"
    )

    # Single mixed-modality call; __call__ flushes immediately.
    outputs = runner(
        text="A photo of a cat",
        image_path="cat.jpg",
        audio_path="meow.wav"
    )
    print("Single inference outputs:", outputs.keys())

    # Batched processing: auto-flushes at every 4th item, then a final
    # explicit flush picks up the remaining two.
    for i in range(10):
        runner.add_to_batch(
            text=f"Sample text {i}",
            image_path=f"image_{i}.jpg",
            audio_path=f"audio_{i}.wav"
        )

    final_outputs = runner.process_batch()
    print("Batch processing outputs:", final_outputs.keys())

    # Multi-modal conditioned generation.
    generated = runner.generate_from_context(
        context_text="Describe this image and sound:",
        context_image="scene.jpg",
        context_audio="ambience.wav",
        max_length=50
    )
    print("Generated sequence shape:", generated.shape)

    runner.cleanup()


if __name__ == "__main__":
    main()