""" Helion-V1.5 Inference Script Simple interface for using the model """ import torch import logging from typing import List, Dict, Optional logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class HelionV15: """Easy-to-use interface for Helion-V1.5.""" def __init__( self, model_name: str = "DeepXR/Helion-V1.5", device: str = "auto", load_in_4bit: bool = False ): """ Initialize Helion-V1.5 model. Args: model_name: Model name or path device: Device to load model on load_in_4bit: Use 4-bit quantization """ from transformers import AutoTokenizer, AutoModelForCausalLM logger.info(f"Loading Helion-V1.5: {model_name}") self.tokenizer = AutoTokenizer.from_pretrained(model_name) load_kwargs = { "device_map": device, "torch_dtype": torch.bfloat16, "trust_remote_code": True } if load_in_4bit: from transformers import BitsAndBytesConfig load_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16 ) self.model = AutoModelForCausalLM.from_pretrained( model_name, **load_kwargs ) self.model.eval() logger.info("Model loaded successfully") def chat( self, messages: List[Dict[str, str]], max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9, do_sample: bool = True ) -> str: """ Generate response from messages. Args: messages: List of message dicts with 'role' and 'content' max_new_tokens: Maximum tokens to generate temperature: Sampling temperature top_p: Nucleus sampling parameter do_sample: Whether to use sampling Returns: Generated response text """ # Apply chat template input_ids = self.tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt" ).to(self.model.device) # Generate with torch.no_grad(): output = self.model.generate( input_ids, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id ) # Decode response response = self.tokenizer.decode( output[0][input_ids.shape[1]:], skip_special_tokens=True ) return response.strip() def generate( self, prompt: str, max_new_tokens: int = 512, **kwargs ) -> str: """ Generate text from a simple prompt. 
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        **kwargs,
    ) -> str:
        """
        Generate text from a simple prompt.

        Args:
            prompt: Input text.
            max_new_tokens: Maximum number of tokens to generate.
            **kwargs: Additional generation parameters passed to chat().

        Returns:
            Generated text.
        """
        messages = [{"role": "user", "content": prompt}]
        return self.chat(messages, max_new_tokens=max_new_tokens, **kwargs)

    def interactive(self):
        """Start an interactive chat session on the terminal."""
        print("\n" + "=" * 60)
        print("Helion-V1.5 Interactive Chat")
        print("Type 'quit' or 'exit' to end")
        print("=" * 60 + "\n")

        conversation = []

        while True:
            user_input = input("You: ").strip()

            if user_input.lower() in ("quit", "exit"):
                print("Goodbye!")
                break

            if not user_input:
                continue

            conversation.append({"role": "user", "content": user_input})

            try:
                response = self.chat(conversation)
                print(f"Helion: {response}\n")
                conversation.append({"role": "assistant", "content": response})
            except Exception as e:
                print(f"Error: {e}")
                conversation.pop()  # Drop the failed user message.


def main():
    """Command-line interface."""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-V1.5 Inference")
    parser.add_argument("--model", default="DeepXR/Helion-V1.5")
    parser.add_argument("--device", default="auto")
    # "4bit" is not a valid Python identifier, so map the flag explicitly
    # to args.load_in_4bit instead of digging through args.__dict__.
    parser.add_argument(
        "--4bit",
        dest="load_in_4bit",
        action="store_true",
        help="Use 4-bit quantization",
    )
    parser.add_argument("--interactive", action="store_true", help="Interactive chat")
    parser.add_argument("--prompt", type=str, help="Single prompt")
    parser.add_argument("--max-tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.7)
    args = parser.parse_args()

    # Initialize the model.
    helion = HelionV15(
        model_name=args.model,
        device=args.device,
        load_in_4bit=args.load_in_4bit,
    )

    if args.interactive:
        helion.interactive()
    elif args.prompt:
        response = helion.generate(
            args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
        )
        print(f"\nResponse:\n{response}")
    else:
        print("Use --interactive or --prompt")


if __name__ == "__main__":
    main()
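
# Example usage (a sketch; assumes this file is saved as "inference.py" and
# that the DeepXR/Helion-V1.5 checkpoint is reachable locally or on the Hub):
#
#   python inference.py --prompt "Summarize attention in two sentences."
#   python inference.py --interactive --4bit
#
# Or from Python:
#
#   from inference import HelionV15
#   helion = HelionV15(load_in_4bit=True)
#   print(helion.generate("Hello!", max_new_tokens=64))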