"""
OmniCoreX Real-Time Inference Pipeline
This module implements a real-time inference pipeline for OmniCoreX,
supporting streaming inputs, adaptive response generation, and dynamic decision making.
Features:
- Streaming input handling with buffer and timeout control.
- Adaptive context management with sliding window history.
- Efficient batching and asynchronous execution for low latency.
- Integration with model's decision-making modules.
- Support for multi-modal inputs and outputs.
- Highly configurable inference parameters.
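Typical usage (a minimal sketch; `model` and `tokenizer` are user-supplied
placeholders, not symbols defined in this module):

    engine = StreamingInference(model, tokenizer=tokenizer)
    engine.start()
    engine.submit_input("Hello!")
    print(engine.get_response(timeout=5.0))  # None if nothing is ready in time
    engine.stop()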
"""
import time
import threading
import queue
from typing import Dict, Optional, List, Any, Callable
import torch
import torch.nn.functional as F
class StreamingInference:
def __init__(self,
model: torch.nn.Module,
tokenizer: Optional[Callable[[str], List[int]]] = None,
device: Optional[torch.device] = None,
max_context_length: int = 512,
max_response_length: int = 128,
streaming_timeout: float = 2.0,
batch_size: int = 1):
"""
Initialize the real-time streaming inference pipeline.
Args:
model: OmniCoreX model instance.
tokenizer: Optional tokenizer for input preprocessing.
device: Device to run inference on.
max_context_length: Max tokens in context window.
max_response_length: Max tokens in generated response.
streaming_timeout: Max seconds to wait for input buffering.
batch_size: Batch size for inference.
"""
self.model = model
self.tokenizer = tokenizer
self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
self.max_context_length = max_context_length
self.max_response_length = max_response_length
self.streaming_timeout = streaming_timeout
self.batch_size = batch_size
self.model.to(self.device)
self.model.eval()
        # Thread-safe queues decouple callers from the background worker thread.
        self.input_queue = queue.Queue()
        self.output_queue = queue.Queue()
        self.context_history: List[str] = []
        self.lock = threading.Lock()  # guards context_history
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._inference_loop, daemon=True)
    def start(self):
        """Start the background inference thread (safe to call again after stop())."""
        self._stop_event.clear()
        if not self._thread.is_alive():
            # A Thread object can only be started once, so create a fresh one.
            self._thread = threading.Thread(target=self._inference_loop, daemon=True)
            self._thread.start()
    def stop(self):
        """Signal the inference thread to stop and wait (up to 5 s) for it to exit."""
        self._stop_event.set()
        self._thread.join(timeout=5)
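
    def __enter__(self) -> "StreamingInference":
        # Convenience sketch (an addition for illustration, not part of the
        # original API): context-manager support so that start()/stop()
        # pairing is automatic, e.g. `with StreamingInference(...) as engine:`.
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.stop()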
def submit_input(self, input_text: str):
"""
Submit streaming input text for inference.
Args:
input_text: Incoming user or sensor input string.
"""
self.input_queue.put(input_text)
def get_response(self, timeout: Optional[float] = None) -> Optional[str]:
"""
Retrieve the next generated response from the output queue.
Args:
timeout: Seconds to wait for response.
Returns:
Generated string response or None if timeout.
"""
try:
response = self.output_queue.get(timeout=timeout)
return response
except queue.Empty:
return None
def _encode_context(self, context_texts: List[str]) -> torch.Tensor:
"""
Converts list of context sentences into token tensor for model input.
Args:
context_texts: List of text strings.
Returns:
Tensor of shape (1, seq_len) on device.
"""
if self.tokenizer is None:
raise RuntimeError("Tokenizer must be provided for text encoding.")
full_text = " ".join(context_texts)
token_ids = self.tokenizer(full_text)
        token_ids = token_ids[-self.max_context_length:]  # left-truncate: keep the most recent tokens
        input_tensor = torch.tensor([token_ids], dtype=torch.long, device=self.device)
return input_tensor
@torch.no_grad()
def _generate_response(self, input_tensor: torch.Tensor) -> str:
"""
Generates text response from model given input tokens.
Args:
input_tensor: Tensor of token ids shape (1, seq_len).
Returns:
Generated string response.
"""
        outputs = self.model(input_tensor)  # expected output shape: (1, seq_len, vocab_size)
        # Single-pass sampling: draw one token per position from the logits of
        # the last (up to) max_response_length positions. Note this is not
        # autoregressive decoding; each token is sampled independently.
        logits = outputs[0, -self.max_response_length:, :]
        probabilities = F.softmax(logits, dim=-1)
        token_ids = torch.multinomial(probabilities, num_samples=1).squeeze(-1).cpu().tolist()
if self.tokenizer and hasattr(self.tokenizer, "decode"):
response = self.tokenizer.decode(token_ids)
else:
# Fallback: Map token ids to chars mod 256 (dummy)
response = "".join([chr(t % 256) for t in token_ids])
return response
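
    @torch.no_grad()
    def _generate_autoregressive(self, input_tensor: torch.Tensor) -> str:
        """
        Alternative sketch (not part of the original pipeline): token-by-token
        autoregressive sampling, assuming the model re-scores the growing
        (prompt + generated) sequence on every step and returns logits of
        shape (batch, seq_len, vocab_size), as `_generate_response` expects.
        """
        tokens = input_tensor
        generated: List[int] = []
        for _ in range(self.max_response_length):
            logits = self.model(tokens)[0, -1, :]  # next-token logits
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # shape (1,)
            generated.append(int(next_id.item()))
            # Feed the sampled token back in for the next step.
            tokens = torch.cat([tokens, next_id.view(1, 1).to(tokens.device)], dim=1)
        if self.tokenizer and hasattr(self.tokenizer, "decode"):
            return self.tokenizer.decode(generated)
        return "".join(chr(t % 256) for t in generated)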
    def _inference_loop(self):
        """
        Background thread: buffers inputs, maintains context history, and generates outputs.
        """
        buffer: List[str] = []
        last_input_time = time.time()
        while not self._stop_event.is_set():
            try:
                # Collect inputs until the batch fills or the stream goes
                # quiet for `streaming_timeout` seconds.
                while not self._stop_event.is_set():
                    try:
                        inp = self.input_queue.get(timeout=0.1)
                        buffer.append(inp)
                        last_input_time = time.time()
                        if len(buffer) >= self.batch_size:
                            break  # batch is full; run inference now
                    except queue.Empty:
                        if time.time() - last_input_time > self.streaming_timeout:
                            break  # stream went quiet; flush what we have
                if not buffer:
                    continue  # nothing new arrived; avoid regenerating from stale context
                with self.lock:
                    # Fold new inputs into the sliding-window context history.
                    self.context_history.extend(buffer)
                    if len(self.context_history) > 20:
                        self.context_history = self.context_history[-20:]
                    cur_context = self.context_history.copy()
                buffer.clear()
                # Encode context and generate a response (outside the lock,
                # since inference may be slow).
                input_tensor = self._encode_context(cur_context)
                response = self._generate_response(input_tensor)
                # Append the response to the shared context history.
                with self.lock:
                    self.context_history.append(response)
                self.output_queue.put(response)
            except Exception as e:
                print(f"[Inference] Exception in inference loop: {e}")
        print("[Inference] Stopped inference loop.")
if __name__ == "__main__":
# Minimal example using dummy tokenizer and dummy model for demonstration.
class DummyTokenizer:
def __call__(self, text):
# Simple char to token id mapping (mod 100 + 1)
return [ord(c) % 100 + 1 for c in text]
def decode(self, token_ids):
return "".join(chr((tid - 1) % 100 + 32) for tid in token_ids)
class DummyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vocab_size = 128
        def forward(self, x):
            batch_size, seq_len = x.shape
            # Random logits on the input's device: (batch, seq_len, vocab_size)
            logits = torch.randn(batch_size, seq_len, self.vocab_size, device=x.device)
            return logits
tokenizer = DummyTokenizer()
model = DummyModel()
inference_engine = StreamingInference(model=model, tokenizer=tokenizer, max_context_length=50)
inference_engine.start()
test_inputs = [
"Hello, OmniCoreX! ",
"How are you today? ",
"Generate a super intelligent response."
]
for inp in test_inputs:
print(f">> Input: {inp.strip()}")
inference_engine.submit_input(inp)
time.sleep(0.5)
output = inference_engine.get_response(timeout=5.0)
if output:
print(f"<< Response: {output}")
inference_engine.stop()
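    # Since DummyModel emits random logits, each "<< Response:" line is a short
    # string of essentially random characters; the demo only exercises the
    # buffering, threading, and context-window plumbing.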