Helion-V1.5 / inference_v15.py
"""
Helion-V1.5 Inference Script
Simple interface for using the model
"""
import torch
import logging
from typing import List, Dict, Optional
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
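# Quick-start example (a minimal sketch, not part of the original script;
# assumes this file is importable as `inference_v15` and that the default
# "DeepXR/Helion-V1.5" checkpoint is reachable, e.g. via the Hugging Face Hub):
#
#     from inference_v15 import HelionV15
#
#     helion = HelionV15()
#     print(helion.chat([{"role": "user", "content": "Hello!"}]))
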
class HelionV15:
    """Easy-to-use interface for Helion-V1.5."""

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V1.5",
        device: str = "auto",
        load_in_4bit: bool = False
    ):
        """
        Initialize Helion-V1.5 model.

        Args:
            model_name: Model name or path
            device: Device to load model on
            load_in_4bit: Use 4-bit quantization
        """
        from transformers import AutoTokenizer, AutoModelForCausalLM

        logger.info(f"Loading Helion-V1.5: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        load_kwargs = {
            "device_map": device,
            "torch_dtype": torch.bfloat16,
            "trust_remote_code": True
        }

        if load_in_4bit:
            from transformers import BitsAndBytesConfig
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **load_kwargs
        )
        self.model.eval()
        logger.info("Model loaded successfully")

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True
    ) -> str:
        """
        Generate a response from a list of chat messages.

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            do_sample: Whether to use sampling

        Returns:
            Generated response text
        """
        # Apply the model's chat template to build the prompt
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.model.device)

        # Fall back to the EOS token for padding; many causal LM
        # tokenizers do not define a dedicated pad token.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        # Generate
        with torch.no_grad():
            output = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, skipping the prompt
        response = self.tokenizer.decode(
            output[0][input_ids.shape[1]:],
            skip_special_tokens=True
        )
        return response.strip()

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        **kwargs
    ) -> str:
        """
        Generate text from a simple prompt.

        Args:
            prompt: Input text
            max_new_tokens: Maximum tokens to generate
            **kwargs: Additional generation parameters

        Returns:
            Generated text
        """
        messages = [{"role": "user", "content": prompt}]
        return self.chat(messages, max_new_tokens=max_new_tokens, **kwargs)

    def interactive(self):
        """Start an interactive chat session."""
        print("\n" + "=" * 60)
        print("Helion-V1.5 Interactive Chat")
        print("Type 'quit' or 'exit' to end")
        print("=" * 60 + "\n")

        conversation = []
        while True:
            try:
                user_input = input("You: ").strip()
            except (KeyboardInterrupt, EOFError):
                # Exit cleanly on Ctrl-C / Ctrl-D instead of crashing
                print("\nGoodbye!")
                break

            if user_input.lower() in ['quit', 'exit']:
                print("Goodbye!")
                break
            if not user_input:
                continue

            conversation.append({"role": "user", "content": user_input})
            try:
                response = self.chat(conversation)
                print(f"Helion: {response}\n")
                conversation.append({"role": "assistant", "content": response})
            except Exception as e:
                print(f"Error: {e}")
                conversation.pop()  # Remove the failed user message

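# 4-bit loading sketch (illustrative only; assumes the optional bitsandbytes
# package is installed, which transformers' BitsAndBytesConfig relies on):
#
#     helion = HelionV15(load_in_4bit=True)  # quantized weights, lower memory
#     print(helion.generate("Summarize why 4-bit quantization saves memory."))
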

def main():
    """Main CLI interface."""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-V1.5 Inference")
    parser.add_argument("--model", default="DeepXR/Helion-V1.5")
    parser.add_argument("--device", default="auto")
    # `dest` is needed here: argparse would otherwise store this flag under
    # the name "4bit", which is not a valid Python attribute name
    parser.add_argument("--4bit", dest="load_in_4bit", action="store_true",
                        help="Use 4-bit quantization")
    parser.add_argument("--interactive", action="store_true", help="Interactive chat")
    parser.add_argument("--prompt", type=str, help="Single prompt")
    parser.add_argument("--max-tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.7)
    args = parser.parse_args()

    # Initialize model
    helion = HelionV15(
        model_name=args.model,
        device=args.device,
        load_in_4bit=args.load_in_4bit
    )

    if args.interactive:
        helion.interactive()
    elif args.prompt:
        response = helion.generate(
            args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature
        )
        print(f"\nResponse:\n{response}")
    else:
        print("Use --interactive or --prompt")


if __name__ == "__main__":
    main()
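
# Example invocations (assuming dependencies are installed and the checkpoint
# is available; flag names match the argparse definitions above):
#
#   python inference_v15.py --prompt "Explain nucleus sampling." --max-tokens 256
#   python inference_v15.py --interactive --4bit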