Kosasih commited on
Commit
7a09e17
·
verified ·
1 Parent(s): c29d461

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +231 -0
inference.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OmniCoreX Real-Time Inference Pipeline
3
+
4
+ This module implements a super advanced, ultra high-tech, real-time inference pipeline for OmniCoreX,
5
+ supporting streaming inputs, adaptive response generation, and dynamic decision making.
6
+
7
+ Features:
8
+ - Streaming input handling with buffer and timeout control.
9
+ - Adaptive context management with sliding window history.
10
+ - Efficient batching and asynchronous execution for low latency.
11
+ - Integration with model's decision-making modules.
12
+ - Support for multi-modal inputs and outputs.
13
+ - Highly configurable inference parameters.
14
+ """
15
+
16
+ import time
17
+ import threading
18
+ import queue
19
+ from typing import Dict, Optional, List, Any, Callable
20
+ import torch
21
+ import torch.nn.functional as F
22
+
23
class StreamingInference:
    """Real-time streaming inference pipeline for OmniCoreX.

    Buffers streaming text inputs on a queue, maintains a sliding-window
    context history, and generates responses on a background thread.

    Thread model: ``submit_input`` and ``get_response`` may be called from
    any thread; the shared context history is guarded by ``self.lock``.
    """

    # Maximum number of entries (inputs + generated responses) kept in the
    # sliding context window.  Was a hard-coded ``20`` inside the loop.
    MAX_HISTORY_ENTRIES = 20

    def __init__(self,
                 model: torch.nn.Module,
                 tokenizer: Optional[Callable[[str], List[int]]] = None,
                 device: Optional[torch.device] = None,
                 max_context_length: int = 512,
                 max_response_length: int = 128,
                 streaming_timeout: float = 2.0,
                 batch_size: int = 1):
        """
        Initialize the real-time streaming inference pipeline.

        Args:
            model: OmniCoreX model instance; must map a (1, seq_len) token
                tensor to (1, seq_len, vocab_size) logits.
            tokenizer: Optional callable mapping text to a list of token ids;
                if it also exposes ``decode``, it is used to render responses.
            device: Device to run inference on; defaults to CUDA if available.
            max_context_length: Max tokens in the context window.
            max_response_length: Max tokens in a generated response.
            streaming_timeout: Seconds of input silence before the current
                buffer is flushed through the model.
            batch_size: Number of buffered inputs that triggers inference.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        self.max_context_length = max_context_length
        self.max_response_length = max_response_length
        self.streaming_timeout = streaming_timeout
        self.batch_size = batch_size

        self.model.to(self.device)
        self.model.eval()

        self.input_queue: "queue.Queue[str]" = queue.Queue()
        self.output_queue: "queue.Queue[str]" = queue.Queue()

        self.context_history: List[str] = []
        self.lock = threading.Lock()

        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._inference_loop, daemon=True)

    def start(self) -> None:
        """Start the background inference processing thread (idempotent)."""
        self._stop_event.clear()
        if not self._thread.is_alive():
            # A Thread object can only be started once; build a fresh one
            # so start() works again after a previous stop().
            self._thread = threading.Thread(target=self._inference_loop, daemon=True)
            self._thread.start()

    def stop(self) -> None:
        """Stop the inference processing thread.

        Safe to call even if ``start`` was never called: joining a thread
        that was never started raises ``RuntimeError``, so liveness is
        checked first (the original called ``join`` unconditionally).
        """
        self._stop_event.set()
        if self._thread.is_alive():
            self._thread.join(timeout=5)

    def submit_input(self, input_text: str) -> None:
        """
        Submit streaming input text for inference.

        Args:
            input_text: Incoming user or sensor input string.
        """
        self.input_queue.put(input_text)

    def get_response(self, timeout: Optional[float] = None) -> Optional[str]:
        """
        Retrieve the next generated response from the output queue.

        Args:
            timeout: Seconds to wait for a response; ``None`` blocks forever.

        Returns:
            Generated string response, or ``None`` on timeout.
        """
        try:
            return self.output_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def _encode_context(self, context_texts: List[str]) -> torch.Tensor:
        """
        Convert a list of context strings into a token tensor for the model.

        Only the most recent ``max_context_length`` tokens are kept.

        Args:
            context_texts: List of text strings (oldest first).

        Returns:
            Long tensor of shape (1, seq_len) on ``self.device``.

        Raises:
            RuntimeError: If no tokenizer was provided.
        """
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer must be provided for text encoding.")
        full_text = " ".join(context_texts)
        token_ids = self.tokenizer(full_text)
        # Sliding window: keep only the newest tokens.
        token_ids = token_ids[-self.max_context_length:]
        return torch.tensor([token_ids], dtype=torch.long, device=self.device)

    @torch.no_grad()
    def _generate_response(self, input_tensor: torch.Tensor) -> str:
        """
        Generate a text response from the model given input tokens.

        Samples one token per position from the softmax of the logits of the
        last ``max_response_length`` positions (single forward pass, not
        autoregressive decoding).

        Args:
            input_tensor: Tensor of token ids, shape (1, seq_len).

        Returns:
            Generated string response.
        """
        outputs = self.model(input_tensor)  # expected shape (1, seq_len, vocab_size)
        logits = outputs[0, -self.max_response_length:, :]
        probabilities = F.softmax(logits, dim=-1)
        token_ids = torch.multinomial(probabilities, num_samples=1).squeeze(-1).cpu().tolist()

        if self.tokenizer and hasattr(self.tokenizer, "decode"):
            return self.tokenizer.decode(token_ids)
        # Fallback: map token ids to chars mod 256 (dummy rendering).
        return "".join(chr(t % 256) for t in token_ids)

    def _inference_loop(self) -> None:
        """
        Background thread: buffer inputs, maintain context, emit responses.

        Flushes the buffer through the model when either ``batch_size``
        inputs have accumulated or the stream has been quiet for
        ``streaming_timeout`` seconds.
        """
        buffer: List[str] = []
        last_input_time = time.time()

        while not self._stop_event.is_set():
            try:
                # Accumulate inputs until the batch fills or the stream goes
                # quiet.  The original never broke on batch_size (every input
                # waited the full timeout) and ignored the stop event here.
                timed_out = False
                while not self._stop_event.is_set():
                    try:
                        buffer.append(self.input_queue.get(timeout=0.1))
                        last_input_time = time.time()
                        if len(buffer) >= self.batch_size:
                            break
                    except queue.Empty:
                        if time.time() - last_input_time > self.streaming_timeout:
                            timed_out = True
                            break

                if timed_out:
                    # Reset the idle clock so an empty stream does not
                    # re-trigger a flush on every 0.1s poll.
                    last_input_time = time.time()

                if not buffer:
                    # Nothing new arrived: skip generation.  The original
                    # re-generated from stale context on every idle timeout,
                    # emitting spurious responses forever.
                    continue

                with self.lock:
                    # Fold the new inputs into the sliding context window.
                    self.context_history.extend(buffer)
                    if len(self.context_history) > self.MAX_HISTORY_ENTRIES:
                        self.context_history = self.context_history[-self.MAX_HISTORY_ENTRIES:]
                    cur_context = self.context_history.copy()
                buffer.clear()

                # Encode context and generate a response (outside the lock).
                input_tensor = self._encode_context(cur_context)
                response = self._generate_response(input_tensor)

                # The model's own output becomes part of future context.
                with self.lock:
                    self.context_history.append(response)

                self.output_queue.put(response)

            except Exception as e:
                # Keep the loop alive on per-item failures, but report them.
                print(f"[Inference] Exception in inference loop: {e}")

        print("[Inference] Stopped inference loop.")
190
+
191
if __name__ == "__main__":
    # Demo run using a stand-in tokenizer and model — no real weights needed.

    class DummyTokenizer:
        """Lossy demo codec: characters to small integer ids and back."""

        def __call__(self, text):
            # Character -> token id in [1, 100].
            return [ord(c) % 100 + 1 for c in text]

        def decode(self, token_ids):
            # Token id -> a character in the printable ASCII range.
            return "".join(chr((tid - 1) % 100 + 32) for tid in token_ids)

    class DummyModel(torch.nn.Module):
        """Emits random logits shaped like a language-model head."""

        def __init__(self):
            super().__init__()
            self.vocab_size = 128

        def forward(self, x):
            batch_size, seq_len = x.shape
            # Random logits tensor: (batch, seq_len, vocab_size).
            return torch.randn(batch_size, seq_len, self.vocab_size)

    tokenizer = DummyTokenizer()
    model = DummyModel()

    inference_engine = StreamingInference(model=model, tokenizer=tokenizer, max_context_length=50)
    inference_engine.start()

    test_inputs = (
        "Hello, OmniCoreX! ",
        "How are you today? ",
        "Generate a super intelligent response.",
    )

    for inp in test_inputs:
        print(f">> Input: {inp.strip()}")
        inference_engine.submit_input(inp)
        time.sleep(0.5)
        output = inference_engine.get_response(timeout=5.0)
        if output:
            print(f"<< Response: {output}")

    inference_engine.stop()