Trouter-Library committed
Commit 8f40540 · verified · Parent: 1914281

Create setup_helion.py

Files changed (1):
  setup_helion.py (+516, −0)
setup_helion.py ADDED
@@ -0,0 +1,516 @@
"""
Helion-OSC Easy Setup & Usage Script
One-file solution for setting up and using the Helion-OSC model

This script handles:
- Automatic dependency installation
- Model loading from HuggingFace Spaces
- GPU/CPU detection
- Memory optimization
- Simple inference interface
- Interactive mode

Usage:
    python setup_helion.py --setup    # First-time setup
    python setup_helion.py --chat     # Interactive chat
    python setup_helion.py --generate "your prompt here"
"""

import subprocess
import sys
import os
import logging
from pathlib import Path

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)


def install_dependencies():
    """Install required dependencies"""
    logger.info("Installing dependencies...")

    dependencies = [
        "torch>=2.0.0",
        "transformers>=4.40.0",
        "accelerate>=0.25.0",
        "sentencepiece>=0.1.99",
        "safetensors>=0.4.0",
        "bitsandbytes>=0.41.0",
        "huggingface-hub>=0.19.0",
        "psutil>=5.9.0",  # used by _check_memory; the report is skipped if absent
    ]

    for dep in dependencies:
        logger.info(f"Installing {dep}...")
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", dep, "-q"])
        except subprocess.CalledProcessError as e:
            logger.warning(f"Failed to install {dep}: {e}")

    logger.info("✓ Dependencies installed")


def check_dependencies():
    """Check if dependencies are installed"""
    required = {
        "torch": "torch",
        "transformers": "transformers",
        "accelerate": "accelerate",
    }

    missing = []
    for name, import_name in required.items():
        try:
            __import__(import_name)
        except ImportError:
            missing.append(name)

    return missing


class HelionOSCEasy:
    """Easy-to-use wrapper for the Helion-OSC model"""

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-OSC",
        device: str = "auto",
        use_8bit: bool = False,
        use_4bit: bool = False,
        trust_remote_code: bool = True
    ):
        """
        Initialize Helion-OSC with automatic configuration

        Args:
            model_name: Model identifier on HuggingFace
            device: Device to use ("auto", "cuda", "mps", "cpu")
            use_8bit: Use 8-bit quantization (saves memory)
            use_4bit: Use 4-bit quantization (saves more memory)
            trust_remote_code: Trust remote code from the model repository
        """
        logger.info("=" * 80)
        logger.info("HELION-OSC EASY SETUP")
        logger.info("=" * 80)

        # Import here, after the dependency check
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        self.model_name = model_name
        self.torch = torch
        self.AutoTokenizer = AutoTokenizer
        self.AutoModelForCausalLM = AutoModelForCausalLM

        # Detect device
        if device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
                logger.info(f"✓ GPU detected: {torch.cuda.get_device_name(0)}")
                logger.info(f"  VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                self.device = "mps"
                logger.info("✓ Apple Silicon (MPS) detected")
            else:
                self.device = "cpu"
                logger.info("⚠ No GPU detected, using CPU (will be slower)")
        else:
            self.device = device
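
        # Note: an explicit `device` argument is passed through unchanged, so a
        # string like "cuda:1" should also work here (an assumption; only
        # "auto", "cuda", "mps", and "cpu" are exercised by this script)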

        # Set quantization
        self.use_8bit = use_8bit
        self.use_4bit = use_4bit

        if self.use_4bit:
            logger.info("Using 4-bit quantization (lowest memory)")
        elif self.use_8bit:
            logger.info("Using 8-bit quantization (reduced memory)")

        # Check available memory
        self._check_memory()

        # Load model
        logger.info(f"\nLoading model: {model_name}")
        logger.info("This may take a few minutes on first run...")

        try:
            self._load_model(trust_remote_code)
            logger.info("✓ Model loaded successfully!")
            self._print_capabilities()
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            logger.info("\nTroubleshooting tips:")
            logger.info("1. Try --use-4bit for lower memory usage")
            logger.info("2. Make sure you have enough RAM/VRAM")
            logger.info("3. Check your internet connection for the download")
            raise

    def _check_memory(self):
        """Check available memory"""
        try:
            import psutil
            ram_gb = psutil.virtual_memory().total / 1e9
            ram_available = psutil.virtual_memory().available / 1e9

            logger.info("\nSystem Memory:")
            logger.info(f"  Total RAM: {ram_gb:.1f} GB")
            logger.info(f"  Available: {ram_available:.1f} GB")

            if self.device == "cuda":
                gpu_mem = self.torch.cuda.get_device_properties(0).total_memory / 1e9
                logger.info(f"  GPU VRAM: {gpu_mem:.1f} GB")
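
                # Rough rule of thumb (general guidance, not from the model
                # card): fp16/bf16 weights take ~2 bytes per parameter, 8-bit
                # ~1 byte, 4-bit ~0.5 bytes, plus activation/KV-cache overhead,
                # which motivates the 8 GB warning threshold below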

                if gpu_mem < 8 and not (self.use_4bit or self.use_8bit):
                    logger.warning("  ⚠ Low VRAM detected. Consider using --use-4bit")

            elif ram_available < 16 and not (self.use_4bit or self.use_8bit):
                logger.warning("  ⚠ Low RAM detected. Consider using --use-4bit")
        except Exception:
            # psutil may be missing or the query may fail; the memory report is
            # purely informational, so skip it rather than abort setup
            pass

    def _load_model(self, trust_remote_code: bool):
        """Load tokenizer and model"""
        # Load tokenizer
        logger.info("Loading tokenizer...")
        self.tokenizer = self.AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=trust_remote_code
        )

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Configure model loading
        model_kwargs = {
            "trust_remote_code": trust_remote_code,
            "low_cpu_mem_usage": True
        }

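        # Memory/quality tradeoff (general bitsandbytes behavior, not a claim
        # specific to Helion-OSC): NF4 4-bit with double quantization has the
        # smallest footprint at some quality cost, 8-bit sits in between, and
        # bf16 keeps full quality but needs the most memory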
        if self.use_4bit:
            from transformers import BitsAndBytesConfig
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=self.torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        elif self.use_8bit:
            # Use BitsAndBytesConfig; the bare load_in_8bit kwarg is deprecated
            from transformers import BitsAndBytesConfig
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        else:
            if self.device == "cuda":
                model_kwargs["torch_dtype"] = self.torch.bfloat16
                model_kwargs["device_map"] = "auto"
            else:
                model_kwargs["torch_dtype"] = self.torch.float32

        # Load model
        logger.info("Loading model weights...")
        self.model = self.AutoModelForCausalLM.from_pretrained(
            self.model_name,
            **model_kwargs
        )

        # Quantized models are placed by bitsandbytes/accelerate, and the CUDA
        # path uses device_map="auto"; otherwise move the model to the chosen
        # device (including MPS, which a plain `== "cpu"` check would skip)
        if self.device in ("cpu", "mps") and not (self.use_4bit or self.use_8bit):
            self.model = self.model.to(self.device)

        self.model.eval()
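
        # eval() disables dropout and other train-time behavior; generate()
        # below additionally runs under torch.no_grad() to avoid autograd
        # bookkeeping during inference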

    def _print_capabilities(self):
        """Print model capabilities"""
        logger.info("\n" + "=" * 80)
        logger.info("MODEL CAPABILITIES")
        logger.info("=" * 80)
        logger.info("✓ Code generation (Python, JavaScript, C++, Java, Rust, Go, etc.)")
        logger.info("✓ Mathematical reasoning and theorem proving")
        logger.info("✓ Algorithm design and optimization")
        logger.info("✓ Code debugging and error fixing")
        logger.info("✓ Step-by-step problem solving")
        logger.info("✓ 250K+ token context length")
        logger.info("=" * 80)

    def generate(
        self,
        prompt: str,
        max_length: int = 2048,
        temperature: float = 0.7,
        top_p: float = 0.95,
        top_k: int = 50,
        do_sample: bool = True,
        verbose: bool = True
    ) -> str:
        """
        Generate text from a prompt

        Args:
            prompt: Input prompt
            max_length: Maximum number of new tokens to generate (capped at 8192)
            temperature: Sampling temperature (higher = more creative)
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            do_sample: Use sampling (False = greedy decoding)
            verbose: Print generation info

        Returns:
            Generated text (without the prompt)
        """
        if verbose:
            logger.info("\nGenerating response...")
            logger.info(f"Prompt length: {len(prompt)} chars")

        # Tokenize
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_length = inputs.input_ids.shape[1]

        if verbose:
            logger.info(f"Input tokens: {input_length}")

        # Generate; max_new_tokens counts only generated tokens (max_length
        # would also count the prompt), capped for reasonable speed
        with self.torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=min(max_length, 8192),
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id
            )
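
        # For long generations, transformers' TextStreamer could print tokens
        # as they are produced (omitted here to keep the wrapper minimal):
        #   from transformers import TextStreamer
        #   streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        #   self.model.generate(**inputs, streamer=streamer, ...)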

        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) is fragile because detokenization does not always
        # reproduce the prompt text exactly
        response = self.tokenizer.decode(
            outputs[0][input_length:], skip_special_tokens=True
        ).strip()

        if verbose:
            output_tokens = outputs.shape[1] - input_length
            logger.info(f"Generated tokens: {output_tokens}")

        return response

    def chat(self, system_prompt: str = "You are Helion-OSC, a helpful AI coding assistant."):
        """Interactive chat mode"""
        logger.info("\n" + "=" * 80)
        logger.info("INTERACTIVE CHAT MODE")
        logger.info("=" * 80)
        logger.info("Commands:")
        logger.info("  /help     - Show this help")
        logger.info("  /clear    - Clear conversation")
        logger.info("  /settings - Change generation settings")
        logger.info("  /quit     - Exit chat")
        logger.info("=" * 80)

        conversation = []
        settings = {
            "temperature": 0.7,
            "max_length": 2048,
            "top_p": 0.95
        }

        while True:
            try:
                user_input = input("\n💬 You: ").strip()

                if not user_input:
                    continue

                if user_input == "/quit":
                    logger.info("Goodbye!")
                    break

                elif user_input == "/help":
                    logger.info("\nAvailable commands:")
                    logger.info("  /help     - Show this help")
                    logger.info("  /clear    - Clear conversation history")
                    logger.info("  /settings - Adjust generation settings")
                    logger.info("  /quit     - Exit chat")
                    continue

                elif user_input == "/clear":
                    conversation = []
                    logger.info("✓ Conversation cleared")
                    continue

                elif user_input == "/settings":
                    logger.info("\nCurrent settings:")
                    logger.info(f"  Temperature: {settings['temperature']}")
                    logger.info(f"  Max length: {settings['max_length']}")
                    logger.info(f"  Top-p: {settings['top_p']}")

                    temp = input("New temperature (0.0-2.0, press Enter to skip): ").strip()
                    if temp:
                        settings['temperature'] = float(temp)

                    max_len = input("New max length (press Enter to skip): ").strip()
                    if max_len:
                        settings['max_length'] = int(max_len)

                    logger.info("✓ Settings updated")
                    continue

                # Build prompt with conversation history
                conversation.append({"role": "user", "content": user_input})

                prompt = system_prompt + "\n\n"
                for msg in conversation:
                    if msg["role"] == "user":
                        prompt += f"User: {msg['content']}\n\n"
                    else:
                        prompt += f"Assistant: {msg['content']}\n\n"
                prompt += "Assistant:"
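
                # This plain "User:/Assistant:" transcript is a generic
                # fallback; if the model ships a chat template, building the
                # prompt with the tokenizer's apply_chat_template would match
                # its training format more closely, e.g. (a sketch, assuming a
                # template exists):
                #   prompt = self.tokenizer.apply_chat_template(
                #       [{"role": "system", "content": system_prompt}] + conversation,
                #       tokenize=False, add_generation_prompt=True)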

                # Generate response
                response = self.generate(
                    prompt,
                    max_length=settings['max_length'],
                    temperature=settings['temperature'],
                    top_p=settings['top_p'],
                    verbose=False
                )

                conversation.append({"role": "assistant", "content": response})

                print(f"\n🤖 Helion: {response}")

            except KeyboardInterrupt:
                logger.info("\n\nGoodbye!")
                break
            except Exception as e:
                logger.error(f"Error: {e}")


def main():
    """Main CLI interface"""
    import argparse

    parser = argparse.ArgumentParser(
        description="Helion-OSC Easy Setup & Usage",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # First-time setup
  python setup_helion.py --setup

  # Interactive chat
  python setup_helion.py --chat

  # Generate from a prompt
  python setup_helion.py --generate "Write a Python function to sort a list"

  # Use 4-bit quantization (low memory)
  python setup_helion.py --chat --use-4bit

  # Generate with custom settings
  python setup_helion.py --generate "Solve x^2 = 16" --temperature 0.3 --max-length 1024
"""
    )

    parser.add_argument(
        "--setup",
        action="store_true",
        help="Install dependencies and set up the model"
    )

    parser.add_argument(
        "--chat",
        action="store_true",
        help="Start interactive chat mode"
    )

    parser.add_argument(
        "--generate",
        type=str,
        help="Generate a response for a prompt"
    )

    parser.add_argument(
        "--model",
        type=str,
        default="DeepXR/Helion-OSC",
        help="Model name on HuggingFace"
    )

    parser.add_argument(
        "--use-4bit",
        action="store_true",
        help="Use 4-bit quantization (lowest memory)"
    )

    parser.add_argument(
        "--use-8bit",
        action="store_true",
        help="Use 8-bit quantization"
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature (default: 0.7)"
    )

    parser.add_argument(
        "--max-length",
        type=int,
        default=2048,
        help="Maximum generation length in new tokens (default: 2048)"
    )

    parser.add_argument(
        "--top-p",
        type=float,
        default=0.95,
        help="Top-p sampling (default: 0.95)"
    )

    args = parser.parse_args()

    # Setup mode
    if args.setup:
        logger.info("Setting up Helion-OSC...")
        install_dependencies()
        logger.info("\n✓ Setup complete!")
        logger.info("\nNext steps:")
        logger.info("  python setup_helion.py --chat")
        return

    # Check dependencies
    missing = check_dependencies()
    if missing:
        logger.error(f"Missing dependencies: {', '.join(missing)}")
        logger.info("Run: python setup_helion.py --setup")
        return

    # Initialize model
    try:
        helion = HelionOSCEasy(
            model_name=args.model,
            use_8bit=args.use_8bit,
            use_4bit=args.use_4bit
        )
    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
        return

    # Chat mode
    if args.chat:
        helion.chat()

    # Generate mode
    elif args.generate:
        response = helion.generate(
            args.generate,
            max_length=args.max_length,
            temperature=args.temperature,
            top_p=args.top_p
        )
        print(f"\n{response}\n")

    # Default: show help
    else:
        logger.info("No action specified. Use --chat or --generate")
        logger.info("Run with --help for more options")


if __name__ == "__main__":
    main()