"""
Helion-V1.5 Inference Script

Simple interface for using the model.
"""
import logging
from typing import Dict, List

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class HelionV15:
    """Easy-to-use interface for Helion-V1.5."""

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V1.5",
        device: str = "auto",
        load_in_4bit: bool = False,
    ):
        """
        Initialize Helion-V1.5 model.

        Args:
            model_name: Model name or path
            device: Device to load model on
            load_in_4bit: Use 4-bit quantization
        """
        from transformers import AutoTokenizer, AutoModelForCausalLM

        logger.info(f"Loading Helion-V1.5: {model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Some tokenizers ship without a pad token; fall back to EOS so that
        # generate() always receives a valid pad_token_id.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        load_kwargs = {
            "device_map": device,
            "torch_dtype": torch.bfloat16,
            "trust_remote_code": True,
        }

        if load_in_4bit:
            from transformers import BitsAndBytesConfig

            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **load_kwargs,
        )
        self.model.eval()
        logger.info("Model loaded successfully")

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
    ) -> str:
        """
        Generate response from messages.

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            do_sample: Whether to use sampling

        Returns:
            Generated response text
        """
        # Render the conversation with the model's chat template and move the
        # token ids to the model's device.
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.model.device)

        with torch.no_grad():
            output = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, skipping the prompt.
        response = self.tokenizer.decode(
            output[0][input_ids.shape[1]:],
            skip_special_tokens=True
        )

        return response.strip()

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        **kwargs
    ) -> str:
        """
        Generate text from a simple prompt.

        Args:
            prompt: Input text
            max_new_tokens: Maximum tokens to generate
            **kwargs: Additional generation parameters

        Returns:
            Generated text
        """
        messages = [{"role": "user", "content": prompt}]
        return self.chat(messages, max_new_tokens=max_new_tokens, **kwargs)

    def interactive(self):
        """Start interactive chat session."""
        print("\n" + "=" * 60)
        print("Helion-V1.5 Interactive Chat")
        print("Type 'quit' or 'exit' to end")
        print("=" * 60 + "\n")

        conversation = []

        while True:
            user_input = input("You: ").strip()

            if user_input.lower() in ['quit', 'exit']:
                print("Goodbye!")
                break

            if not user_input:
                continue

            conversation.append({"role": "user", "content": user_input})

            try:
                response = self.chat(conversation)
                print(f"Helion: {response}\n")
                conversation.append({"role": "assistant", "content": response})
            except Exception as e:
                print(f"Error: {e}")
                # Drop the failed user turn so the history stays consistent.
                conversation.pop()


def main():
    """Main CLI interface."""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-V1.5 Inference")
    parser.add_argument("--model", default="DeepXR/Helion-V1.5")
    parser.add_argument("--device", default="auto")
    # "4bit" is not a valid Python identifier, so give the flag an explicit dest.
    parser.add_argument("--4bit", dest="load_in_4bit", action="store_true",
                        help="Use 4-bit quantization")
    parser.add_argument("--interactive", action="store_true", help="Interactive chat")
    parser.add_argument("--prompt", type=str, help="Single prompt")
    parser.add_argument("--max-tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.7)

    args = parser.parse_args()

    helion = HelionV15(
        model_name=args.model,
        device=args.device,
        load_in_4bit=args.load_in_4bit,
    )

    if args.interactive:
        helion.interactive()
    elif args.prompt:
        response = helion.generate(
            args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
        )
        print(f"\nResponse:\n{response}")
    else:
        print("Use --interactive or --prompt")


if __name__ == "__main__":
    main()
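
# Example invocations (a sketch; the filename "helion_inference.py" is
# illustrative, use this file's actual name):
#   python helion_inference.py --prompt "Explain quantization in one paragraph."
#   python helion_inference.py --interactive --4bit --max-tokens 256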