"""
Helion-V1.5-XL Inference Script
Supports multiple inference modes and optimization techniques
"""
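
# Example CLI usage (the script filename is illustrative; the flags are defined in main() below):
#   python inference.py --prompt "Explain attention in one paragraph" --max-tokens 256
#   python inference.py --prompt "Hello!" --chat-mode --load-in-4bit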

import argparse
import time
from typing import Dict, List

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    GenerationConfig
)


class HelionInference:
    """Inference wrapper for Helion-V1.5-XL"""

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V1.5-XL",
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        device_map: str = "auto",
        torch_dtype: str = "bfloat16"
    ):
        """
        Initialize the model and tokenizer

        Args:
            model_name: HuggingFace model identifier
            load_in_4bit: Enable 4-bit quantization
            load_in_8bit: Enable 8-bit quantization
            device_map: Device mapping strategy
            torch_dtype: PyTorch dtype for model weights
        """
        self.model_name = model_name
        print(f"Loading model: {model_name}")

        # Map the dtype string to the corresponding torch dtype.
        dtype_map = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32
        }
        torch_dtype = dtype_map.get(torch_dtype, torch.bfloat16)

        # Optional quantization (requires the bitsandbytes package).
        quantization_config = None
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        # Many causal LM tokenizers ship without a pad token; fall back to EOS
        # so that padding and generation work without errors.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load the model, quantized if a quantization config was built above.
        model_kwargs = {
            "device_map": device_map,
            "trust_remote_code": True,
        }

        if quantization_config:
            model_kwargs["quantization_config"] = quantization_config
        else:
            model_kwargs["torch_dtype"] = torch_dtype

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **model_kwargs
        )

        self.model.eval()
        print("Model loaded successfully!")

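    # Hypothetical instantiation sketch (4-bit loading is optional and requires bitsandbytes):
    #   engine = HelionInference(load_in_4bit=True)
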
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        do_sample: bool = True,
        num_return_sequences: int = 1,
        **kwargs
    ) -> List[str]:
        """
        Generate text from a prompt

        Args:
            prompt: Input text prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (must be > 0 when sampling; higher is more random)
            top_p: Nucleus sampling threshold
            top_k: Top-k sampling threshold
            repetition_penalty: Penalty for repetition
            do_sample: Whether to use sampling
            num_return_sequences: Number of sequences to generate

        Returns:
            List of generated text strings
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            **kwargs
        )

        start_time = time.time()

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                generation_config=generation_config
            )

        generation_time = time.time() - start_time

        # Decode only the newly generated tokens, skipping the echoed prompt.
        prompt_length = inputs["input_ids"].shape[-1]
        responses = []
        for output in outputs:
            response = self.tokenizer.decode(
                output[prompt_length:], skip_special_tokens=True
            ).strip()
            responses.append(response)

        # Report throughput over newly generated tokens only.
        new_tokens = sum(len(output) - prompt_length for output in outputs)
        tokens_per_sec = new_tokens / max(generation_time, 1e-6)

        print("\nGeneration Stats:")
        print(f"  Time: {generation_time:.2f}s")
        print(f"  Tokens/sec: {tokens_per_sec:.2f}")

        return responses

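    # Sketch of a direct generate() call (hypothetical `engine` instance):
    #   outputs = engine.generate("Write a haiku about the sea", max_new_tokens=64)
    #   print(outputs[0])
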
    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate response in chat format

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response string
        """
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        responses = self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **kwargs
        )

        return responses[0]

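    # Sketch of a chat() call (hypothetical `engine`; accepted roles depend on the model's chat template):
    #   reply = engine.chat(
    #       [{"role": "user", "content": "Summarize RLHF in two sentences."}],
    #       max_new_tokens=128
    #   )
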
    def batch_generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 512,
        **kwargs
    ) -> List[str]:
        """
        Generate responses for multiple prompts in a single batch

        Args:
            prompts: List of input prompts
            max_new_tokens: Maximum tokens per generation

        Returns:
            List of generated responses
        """
        # Left-pad so generated tokens stay contiguous with each prompt
        # (needed for correct batched generation with decoder-only models).
        self.tokenizer.padding_side = "left"
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )

        # Decode only the tokens generated after the padded prompt block.
        prompt_length = inputs["input_ids"].shape[-1]
        responses = []
        for output in outputs:
            response = self.tokenizer.decode(
                output[prompt_length:], skip_special_tokens=True
            ).strip()
            responses.append(response)

        return responses

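# Sketch of batched programmatic usage (hypothetical prompts; assumes sufficient GPU memory):
#   engine = HelionInference(load_in_4bit=True)
#   print(engine.batch_generate(["Define entropy.", "Define enthalpy."], max_new_tokens=64))
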

def main():
    parser = argparse.ArgumentParser(description="Helion-V1.5-XL Inference")
    parser.add_argument(
        "--model",
        type=str,
        default="DeepXR/Helion-V1.5-XL",
        help="Model name or path"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Input prompt"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.9,
        help="Nucleus sampling threshold"
    )
    parser.add_argument(
        "--load-in-4bit",
        action="store_true",
        help="Load model in 4-bit quantization"
    )
    parser.add_argument(
        "--load-in-8bit",
        action="store_true",
        help="Load model in 8-bit quantization"
    )
    parser.add_argument(
        "--chat-mode",
        action="store_true",
        help="Use chat format"
    )

    args = parser.parse_args()

    inference = HelionInference(
        model_name=args.model,
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit
    )

    if args.chat_mode:
        messages = [
            {"role": "user", "content": args.prompt}
        ]
        response = inference.chat(
            messages,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p
        )
    else:
        responses = inference.generate(
            args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p
        )
        response = responses[0]

    print("\n" + "=" * 80)
    print("PROMPT:")
    print("=" * 80)
    print(args.prompt)
    print("\n" + "=" * 80)
    print("RESPONSE:")
    print("=" * 80)
    print(response)
    print("=" * 80)


if __name__ == "__main__":
    main()