Trouter-Library committed
Commit fdcf9f5 · verified · 1 parent: 7e21c95

Create multi_model_inference.py

Files changed (1):
  1. multi_model_inference.py +501 -0
multi_model_inference.py ADDED
"""
Multi-Model Inference System for Helion-OSC
Supports 4 different model variants for specialized tasks
"""

import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelType(Enum):
    """Available model types"""
    BASE = "base"            # General-purpose coding
    MATH = "math"            # Mathematical reasoning
    ALGORITHM = "algorithm"  # Algorithm design & optimization
    DEBUG = "debug"          # Code debugging & fixing


@dataclass
class ModelConfig:
    """Configuration for each model variant"""
    name: str
    model_path: str
    description: str
    default_temperature: float
    default_max_length: int
    default_top_p: float


class MultiModelInference:
    """
    Multi-model inference system with 4 specialized models
    """

    # Model configurations
    MODELS = {
        ModelType.BASE: ModelConfig(
            name="Helion-OSC Base",
            model_path="DeepXR/Helion-OSC",
            description="General purpose code generation and completion",
            default_temperature=0.7,
            default_max_length=2048,
            default_top_p=0.95
        ),
        ModelType.MATH: ModelConfig(
            name="Helion-OSC Math",
            model_path="DeepXR/Helion-OSC",  # In production, use specialized variant
            description="Mathematical reasoning and theorem proving",
            default_temperature=0.3,
            default_max_length=2048,
            default_top_p=0.9
        ),
        ModelType.ALGORITHM: ModelConfig(
            name="Helion-OSC Algorithm",
            model_path="DeepXR/Helion-OSC",  # In production, use specialized variant
            description="Algorithm design and optimization",
            default_temperature=0.5,
            default_max_length=3072,
            default_top_p=0.93
        ),
        ModelType.DEBUG: ModelConfig(
            name="Helion-OSC Debug",
            model_path="DeepXR/Helion-OSC",  # In production, use specialized variant
            description="Code debugging and error fixing",
            default_temperature=0.4,
            default_max_length=2048,
            default_top_p=0.88
        )
    }

    def __init__(
        self,
        device: Optional[str] = None,
        load_all_models: bool = False,
        use_8bit: bool = False
    ):
        """
        Initialize the multi-model inference system.

        Args:
            device: Device to use (cuda/cpu)
            load_all_models: Load all models at startup (uses more memory)
            use_8bit: Use 8-bit quantization for memory efficiency
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.use_8bit = use_8bit
        self.loaded_models: Dict[ModelType, Any] = {}
        self.tokenizers: Dict[ModelType, Any] = {}

        logger.info(f"Initializing Multi-Model Inference System on {self.device}")

        if load_all_models:
            logger.info("Loading all models at startup...")
            for model_type in ModelType:
                self._load_model(model_type)
        else:
            logger.info("Models will be loaded on-demand")

    def _load_model(self, model_type: ModelType):
        """Load a specific model variant"""
        if model_type in self.loaded_models:
            logger.info(f"{model_type.value} model already loaded")
            return

        config = self.MODELS[model_type]
        logger.info(f"Loading {config.name}...")

        try:
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                config.model_path,
                trust_remote_code=True
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Load model
            model_kwargs = {
                "trust_remote_code": True,
                "low_cpu_mem_usage": True
            }

            if self.use_8bit:
                model_kwargs["load_in_8bit"] = True
            elif self.device == "cuda":
                model_kwargs["torch_dtype"] = torch.bfloat16
                model_kwargs["device_map"] = "auto"
            else:
                model_kwargs["torch_dtype"] = torch.float32

            model = AutoModelForCausalLM.from_pretrained(
                config.model_path,
                **model_kwargs
            )

            if self.device == "cpu" and not self.use_8bit:
                model = model.to(self.device)

            model.eval()

            self.loaded_models[model_type] = model
            self.tokenizers[model_type] = tokenizer

            logger.info(f"✓ {config.name} loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load {config.name}: {e}")
            raise

    def _ensure_model_loaded(self, model_type: ModelType):
        """Ensure a model is loaded before use"""
        if model_type not in self.loaded_models:
            self._load_model(model_type)

    def generate(
        self,
        prompt: str,
        model_type: ModelType = ModelType.BASE,
        max_length: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: int = 50,
        do_sample: Optional[bool] = None,
        num_return_sequences: int = 1,
        **kwargs
    ) -> Union[str, List[str]]:
        """
        Generate text using the specified model.

        Args:
            prompt: Input prompt
            model_type: Which model to use
            max_length: Maximum generation length
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            do_sample: Whether to use sampling
            num_return_sequences: Number of sequences to generate
            **kwargs: Additional generation parameters

        Returns:
            Generated text (a list of strings when num_return_sequences > 1)
        """
        self._ensure_model_loaded(model_type)

        config = self.MODELS[model_type]
        model = self.loaded_models[model_type]
        tokenizer = self.tokenizers[model_type]

        # Fall back to the model's defaults only when a value was not given
        # (an explicit temperature of 0.0 must not be overridden).
        max_length = max_length if max_length is not None else config.default_max_length
        temperature = temperature if temperature is not None else config.default_temperature
        top_p = top_p if top_p is not None else config.default_top_p
        do_sample = do_sample if do_sample is not None else (temperature > 0)

        logger.info(f"Generating with {config.name}...")

        # Tokenize
        inputs = tokenizer(prompt, return_tensors="pt").to(self.device)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
                num_return_sequences=num_return_sequences,
                pad_token_id=tokenizer.eos_token_id,
                **kwargs
            )

        # Decode, slicing the prompt off at the token level (character-level
        # slicing is unreliable because decoding may not reproduce the prompt
        # verbatim).
        prompt_length = inputs["input_ids"].shape[1]
        if num_return_sequences == 1:
            return tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()

        results = []
        for output in outputs:
            results.append(
                tokenizer.decode(output[prompt_length:], skip_special_tokens=True).strip()
            )
        return results

    def code_generation(
        self,
        prompt: str,
        language: Optional[str] = None,
        **kwargs
    ) -> str:
        """Generate code using the base model"""
        if language:
            prompt = f"Language: {language}\n\n{prompt}"

        return self.generate(
            prompt,
            model_type=ModelType.BASE,
            **kwargs
        )

    def solve_math(
        self,
        problem: str,
        show_steps: bool = True,
        **kwargs
    ) -> str:
        """Solve a mathematical problem using the math model"""
        if show_steps:
            prompt = f"Solve the following problem step by step:\n\n{problem}\n\nSolution:"
        else:
            prompt = f"Solve: {problem}\n\nAnswer:"

        return self.generate(
            prompt,
            model_type=ModelType.MATH,
            **kwargs
        )

    def design_algorithm(
        self,
        problem: str,
        include_complexity: bool = True,
        **kwargs
    ) -> str:
        """Design an algorithm using the algorithm model"""
        prompt = f"Design an efficient algorithm for:\n\n{problem}"
        if include_complexity:
            prompt += "\n\nInclude time and space complexity analysis."

        return self.generate(
            prompt,
            model_type=ModelType.ALGORITHM,
            **kwargs
        )

    def debug_code(
        self,
        code: str,
        error_message: Optional[str] = None,
        language: str = "python",
        **kwargs
    ) -> str:
        """Debug code using the debug model"""
        prompt = f"Debug the following {language} code:\n\n```{language}\n{code}\n```"
        if error_message:
            prompt += f"\n\nError: {error_message}"
        prompt += "\n\nProvide analysis and fixed code:"

        return self.generate(
            prompt,
            model_type=ModelType.DEBUG,
            **kwargs
        )

    def get_loaded_models(self) -> List[str]:
        """Get list of currently loaded models"""
        return [self.MODELS[mt].name for mt in self.loaded_models.keys()]

    def unload_model(self, model_type: ModelType):
        """Unload a model to free memory"""
        if model_type in self.loaded_models:
            del self.loaded_models[model_type]
            del self.tokenizers[model_type]
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info(f"Unloaded {self.MODELS[model_type].name}")

    def unload_all(self):
        """Unload all models"""
        for model_type in list(self.loaded_models.keys()):
            self.unload_model(model_type)
        logger.info("All models unloaded")


def demonstrate_all_models(load_all_models: bool = False, use_8bit: bool = False):
    """Demonstrate all 4 models"""
    print("="*80)
    print("HELION-OSC MULTI-MODEL INFERENCE DEMONSTRATION")
    print("="*80)

    # Initialize the system (models load on demand unless load_all_models is set)
    system = MultiModelInference(load_all_models=load_all_models, use_8bit=use_8bit)

    # Example 1: Base Model - General Code Generation
    print("\n" + "="*80)
    print("MODEL 1: BASE - General Code Generation")
    print("="*80)
    prompt1 = "Write a Python function to check if a string is a palindrome:"
    print(f"Prompt: {prompt1}")
    print("\nGenerating...")
    result1 = system.code_generation(prompt1, language="python", max_length=512)
    print(f"\nResult:\n{result1}\n")

    # Example 2: Math Model - Mathematical Reasoning
    print("\n" + "="*80)
    print("MODEL 2: MATH - Mathematical Reasoning")
    print("="*80)
    prompt2 = "Find the derivative of f(x) = 3x^4 - 2x^3 + 5x - 7"
    print(f"Prompt: {prompt2}")
    print("\nGenerating...")
    result2 = system.solve_math(prompt2, show_steps=True, max_length=1024)
    print(f"\nResult:\n{result2}\n")

    # Example 3: Algorithm Model - Algorithm Design
    print("\n" + "="*80)
    print("MODEL 3: ALGORITHM - Algorithm Design")
    print("="*80)
    prompt3 = "Find the longest common subsequence of two strings"
    print(f"Prompt: {prompt3}")
    print("\nGenerating...")
    result3 = system.design_algorithm(prompt3, include_complexity=True, max_length=2048)
    print(f"\nResult:\n{result3}\n")

    # Example 4: Debug Model - Code Debugging
    print("\n" + "="*80)
    print("MODEL 4: DEBUG - Code Debugging")
    print("="*80)
    buggy_code = """
def factorial(n):
    if n == 0:
        return 1
    return n * factorial(n)
"""
    print(f"Buggy Code:\n{buggy_code}")
    print("\nGenerating debugging analysis...")
    result4 = system.debug_code(
        buggy_code,
        error_message="RecursionError: maximum recursion depth exceeded",
        max_length=1024
    )
    print(f"\nResult:\n{result4}\n")

    # Show loaded models
    print("="*80)
    print("LOADED MODELS:")
    print("="*80)
    for model_name in system.get_loaded_models():
        print(f"✓ {model_name}")

    print("\n" + "="*80)
    print("DEMONSTRATION COMPLETE")
    print("="*80)


def interactive_mode(load_all_models: bool = False, use_8bit: bool = False):
    """Interactive mode for testing models"""
    system = MultiModelInference(load_all_models=load_all_models, use_8bit=use_8bit)

    print("\n" + "="*80)
    print("HELION-OSC INTERACTIVE MODE")
    print("="*80)
    print("\nAvailable commands:")
    print("  1      - Generate code (Base model)")
    print("  2      - Solve math (Math model)")
    print("  3      - Design algorithm (Algorithm model)")
    print("  4      - Debug code (Debug model)")
    print("  models - Show loaded models")
    print("  quit   - Exit")
    print("="*80)

    while True:
        try:
            command = input("\nEnter command (1-4, models, or quit): ").strip().lower()

            if command == "quit":
                print("Exiting...")
                break

            elif command == "models":
                loaded = system.get_loaded_models()
                if loaded:
                    print("\nLoaded models:")
                    for model in loaded:
                        print(f"  ✓ {model}")
                else:
                    print("\nNo models loaded yet")

            elif command == "1":
                prompt = input("\nEnter code generation prompt: ")
                language = input("Programming language (or press Enter for Python): ").strip() or "python"
                print("\nGenerating...")
                result = system.code_generation(prompt, language=language)
                print(f"\n{result}\n")

            elif command == "2":
                problem = input("\nEnter math problem: ")
                print("\nSolving...")
                result = system.solve_math(problem)
                print(f"\n{result}\n")

            elif command == "3":
                problem = input("\nEnter algorithm problem: ")
                print("\nDesigning algorithm...")
                result = system.design_algorithm(problem)
                print(f"\n{result}\n")

            elif command == "4":
                print("\nEnter code to debug (type 'END' on a new line when done):")
                code_lines = []
                while True:
                    line = input()
                    if line == "END":
                        break
                    code_lines.append(line)
                code = "\n".join(code_lines)
                error = input("\nError message (optional): ").strip() or None
                print("\nDebugging...")
                result = system.debug_code(code, error_message=error)
                print(f"\n{result}\n")

            else:
                print("Invalid command. Please try again.")

        except KeyboardInterrupt:
            print("\n\nExiting...")
            break
        except Exception as e:
            print(f"\nError: {e}")

    system.unload_all()


def main():
    """Main entry point"""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-OSC Multi-Model Inference")
    parser.add_argument(
        "--mode",
        choices=["demo", "interactive"],
        default="demo",
        help="Run mode: demo or interactive"
    )
    parser.add_argument(
        "--load-all",
        action="store_true",
        help="Load all models at startup"
    )
    parser.add_argument(
        "--use-8bit",
        action="store_true",
        help="Use 8-bit quantization"
    )

    args = parser.parse_args()

    # Forward the CLI flags (previously parsed but never used)
    if args.mode == "demo":
        demonstrate_all_models(load_all_models=args.load_all, use_8bit=args.use_8bit)
    else:
        interactive_mode(load_all_models=args.load_all, use_8bit=args.use_8bit)


if __name__ == "__main__":
    main()
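
A minimal usage sketch (assuming the file above is importable as `multi_model_inference` and the `DeepXR/Helion-OSC` weights are reachable; the prompts here are illustrative, not from the commit):

    from multi_model_inference import MultiModelInference, ModelType

    # Models load lazily on first use; pass load_all_models=True to preload all four.
    system = MultiModelInference(use_8bit=False)

    # Either route a request to a variant explicitly...
    code = system.generate(
        "Write a binary search function in Python:",
        model_type=ModelType.BASE,
        max_length=512,
    )

    # ...or use a task helper, which picks the variant and builds the prompt.
    analysis = system.debug_code(
        "def last(xs): return xs[len(xs)]",
        error_message="IndexError: list index out of range",
    )

    system.unload_all()  # release model memory when finished

The module also doubles as a CLI: `python multi_model_inference.py --mode interactive --use-8bit` starts the interactive loop with 8-bit loading, and `--load-all` preloads every variant at startup.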