# File size: 8,453 Bytes
# 7a09e17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
"""
OmniCoreX Real-Time Inference Pipeline

This module implements a real-time inference pipeline for OmniCoreX,
supporting streaming inputs, adaptive response generation, and dynamic decision making.

Features:
- Streaming input handling with buffer and timeout control.
- Adaptive context management with sliding window history.
- Efficient batching and asynchronous execution for low latency.
- Integration with model's decision-making modules.
- Support for multi-modal inputs and outputs.
- Highly configurable inference parameters.
"""

import time
import threading
import queue
from typing import Dict, Optional, List, Any, Callable
import torch
import torch.nn.functional as F

class StreamingInference:
    """Real-time streaming inference pipeline.

    Buffers streaming text inputs on a queue, maintains a sliding-window
    context history, and generates responses on a background worker thread.

    Thread-safety: ``context_history`` is guarded by ``self.lock``; the
    input/output queues are thread-safe by construction.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 tokenizer: Optional[Callable[[str], List[int]]] = None,
                 device: Optional[torch.device] = None,
                 max_context_length: int = 512,
                 max_response_length: int = 128,
                 streaming_timeout: float = 2.0,
                 batch_size: int = 1,
                 context_window: int = 20):
        """
        Initialize the real-time streaming inference pipeline.

        Args:
            model: OmniCoreX model instance.
            tokenizer: Optional tokenizer for input preprocessing.
            device: Device to run inference on (auto-selects CUDA when available).
            max_context_length: Max tokens in context window.
            max_response_length: Max tokens in generated response.
            streaming_timeout: Max seconds to wait for input buffering.
            batch_size: Batch size for inference.
            context_window: Max entries kept in the sliding context history
                (previously a hard-coded constant of 20).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        self.max_context_length = max_context_length
        self.max_response_length = max_response_length
        self.streaming_timeout = streaming_timeout
        self.batch_size = batch_size
        self.context_window = context_window

        self.model.to(self.device)
        self.model.eval()

        self.input_queue = queue.Queue()
        self.output_queue = queue.Queue()

        self.context_history: List[str] = []
        self.lock = threading.Lock()

        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._inference_loop, daemon=True)

    def start(self):
        """Start the background inference processing thread (idempotent)."""
        self._stop_event.clear()
        if not self._thread.is_alive():
            # A Thread object can only be started once; create a fresh one
            # so start() works again after a previous stop().
            self._thread = threading.Thread(target=self._inference_loop, daemon=True)
            self._thread.start()

    def stop(self):
        """Stop the inference processing thread.

        Safe to call even if start() was never invoked: joining a thread
        that was never started raises RuntimeError, so guard on is_alive().
        """
        self._stop_event.set()
        if self._thread.is_alive():
            self._thread.join(timeout=5)

    def submit_input(self, input_text: str):
        """
        Submit streaming input text for inference.

        Args:
            input_text: Incoming user or sensor input string.
        """
        self.input_queue.put(input_text)

    def get_response(self, timeout: Optional[float] = None) -> Optional[str]:
        """
        Retrieve the next generated response from the output queue.

        Args:
            timeout: Seconds to wait for response (None blocks indefinitely).

        Returns:
            Generated string response or None if timeout.
        """
        try:
            return self.output_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def _encode_context(self, context_texts: List[str]) -> torch.Tensor:
        """
        Converts list of context sentences into token tensor for model input.

        Args:
            context_texts: List of text strings.

        Returns:
            Tensor of shape (1, seq_len) on the configured device.

        Raises:
            RuntimeError: If no tokenizer was provided.
        """
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer must be provided for text encoding.")
        full_text = " ".join(context_texts)
        token_ids = self.tokenizer(full_text)
        # Keep only the most recent tokens that fit in the context window.
        token_ids = token_ids[-self.max_context_length:]
        return torch.tensor([token_ids], dtype=torch.long, device=self.device)

    @torch.no_grad()
    def _generate_response(self, input_tensor: torch.Tensor) -> str:
        """
        Generates a text response from the model given input tokens.

        NOTE: this samples one token per position from the softmax of the
        last ``max_response_length`` positions' logits in a single forward
        pass — it is not autoregressive decoding.

        Args:
            input_tensor: Tensor of token ids, shape (1, seq_len).

        Returns:
            Generated string response.
        """
        outputs = self.model(input_tensor)  # Expected output shape (1, seq_len, vocab_size)
        logits = outputs[0, -self.max_response_length:, :]
        probabilities = F.softmax(logits, dim=-1)
        token_ids = torch.multinomial(probabilities, num_samples=1).squeeze(-1).cpu().tolist()

        if self.tokenizer and hasattr(self.tokenizer, "decode"):
            return self.tokenizer.decode(token_ids)
        # Fallback: map token ids to chars mod 256 (dummy decoding).
        return "".join(chr(t % 256) for t in token_ids)

    def _inference_loop(self):
        """
        Background thread: drain inputs, maintain context, emit outputs.

        Fix over the original: a flush only happens when new input has
        actually arrived (non-empty buffer). Previously, once the idle
        timeout fired, the loop re-generated from stale context forever
        because last_input_time was never reset.
        """
        buffer: List[str] = []
        last_input_time = time.time()

        while not self._stop_event.is_set():
            try:
                # Drain whatever is currently queued (poll with a short
                # timeout so stop() is noticed promptly).
                timed_out = False
                while True:
                    try:
                        item = self.input_queue.get(timeout=0.1)
                        buffer.append(item)
                        last_input_time = time.time()
                    except queue.Empty:
                        if time.time() - last_input_time > self.streaming_timeout:
                            timed_out = True
                        break

                # Nothing new arrived: nothing to flush.
                if not buffer:
                    continue

                if timed_out or len(buffer) >= self.batch_size:
                    with self.lock:
                        # Fold new inputs into the running context history.
                        self.context_history.extend(buffer)
                        # Sliding window over the most recent entries.
                        if len(self.context_history) > self.context_window:
                            self.context_history = self.context_history[-self.context_window:]
                        cur_context = self.context_history.copy()
                        buffer.clear()

                    # Encode context and generate response (outside the lock).
                    input_tensor = self._encode_context(cur_context)
                    response = self._generate_response(input_tensor)

                    # Append response to context history.
                    with self.lock:
                        self.context_history.append(response)

                    self.output_queue.put(response)
                    # Reset the idle clock so a single timeout cannot re-fire.
                    last_input_time = time.time()

            except Exception as e:
                # Keep the worker alive, but back off to avoid a tight
                # error loop if the model/tokenizer fails repeatedly.
                print(f"[Inference] Exception in inference loop: {e}")
                time.sleep(0.1)

        print("[Inference] Stopped inference loop.")

if __name__ == "__main__":
    # Minimal example using dummy tokenizer and dummy model for demonstration.

    class DummyTokenizer:
        def __call__(self, text):
            # Simple char to token id mapping (mod 100 + 1)
            return [ord(c) % 100 + 1 for c in text]
        def decode(self, token_ids):
            return "".join(chr((tid - 1) % 100 + 32) for tid in token_ids)

    class DummyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.vocab_size = 128
        def forward(self, x):
            batch_size, seq_len = x.shape
            # Return random logits tensor: (batch, seq_len, vocab_size)
            logits = torch.randn(batch_size, seq_len, self.vocab_size)
            return logits

    tokenizer = DummyTokenizer()
    model = DummyModel()

    inference_engine = StreamingInference(model=model, tokenizer=tokenizer, max_context_length=50)
    inference_engine.start()

    test_inputs = [
        "Hello, OmniCoreX! ",
        "How are you today? ",
        "Generate a super intelligent response."
    ]

    for inp in test_inputs:
        print(f">> Input: {inp.strip()}")
        inference_engine.submit_input(inp)
        time.sleep(0.5)
        output = inference_engine.get_response(timeout=5.0)
        if output:
            print(f"<< Response: {output}")

    inference_engine.stop()