"""
Helion-V1 Inference Script
Safe and helpful conversational AI model
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict
import warnings
warnings.filterwarnings('ignore')


class HelionInference:
    def __init__(self, model_name: str = "DeepXR/Helion-V1", device: str = "auto"):
        """
        Initialize the Helion model for inference.

        Args:
            model_name: Hugging Face model identifier
            device: Device to run inference on ('cuda', 'cpu', or 'auto')
        """
        print(f"Loading Helion-V1 model from {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True
        )
        self.model.eval()
        print("Model loaded successfully!")

        # Safety keywords to monitor in user input
        self.safety_keywords = [
            "harm", "illegal", "weapon", "violence", "dangerous",
            "exploit", "hack", "steal", "abuse"
        ]
    def check_safety(self, text: str) -> bool:
        """
        Basic safety check on input text.

        Args:
            text: Input text to check

        Returns:
            True if text appears safe, False otherwise
        """
        text_lower = text.lower()
        for keyword in self.safety_keywords:
            if keyword in text_lower:
                return False
        return True
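    # Behavior sketch of the substring filter above, for illustration:
    #   check_safety("how do I hack a server")  -> False ("hack" matched)
    #   check_safety("that joke was harmless")  -> False ("harm" is a substring)
    #   check_safety("tell me about otters")    -> True
    # A naive keyword match both over- and under-blocks; treat it as a
    # coarse first pass, not a real moderation layer.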
    def generate_response(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True
    ) -> str:
        """
        Generate a response from the model.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            do_sample: Whether to use sampling

        Returns:
            Generated response text
        """
        # Apply the chat template and move the inputs to the model's device
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.model.device)

        # Some tokenizers ship without a pad token; fall back to EOS
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        # Generate the response
        with torch.no_grad():
            output = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, skipping the prompt
        response = self.tokenizer.decode(
            output[0][input_ids.shape[1]:],
            skip_special_tokens=True
        )
        return response.strip()
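    # Example calls (assuming an instance `helion = HelionInference()` exists):
    #   helion.generate_response(messages, do_sample=False)   # greedy decoding
    #   helion.generate_response(messages, temperature=0.2)   # more conservative
    # where `messages` follows the chat format, e.g.
    #   [{"role": "user", "content": "Hi!"}]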
    def chat(self):
        """Interactive chat mode."""
        print("\n" + "=" * 60)
        print("Helion-V1 Interactive Chat")
        print("Type 'quit' or 'exit' to end the conversation")
        print("=" * 60 + "\n")

        conversation_history = []
        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                # Exit cleanly on Ctrl-C / Ctrl-D instead of a traceback
                print("\nGoodbye! Have a great day!")
                break

            if user_input.lower() in ['quit', 'exit']:
                print("Goodbye! Have a great day!")
                break
            if not user_input:
                continue

            # Basic safety check
            if not self.check_safety(user_input):
                print("Helion: I apologize, but I can't assist with that request. "
                      "Let me know if there's something else I can help you with!")
                continue

            # Add user message to history
            conversation_history.append({
                "role": "user",
                "content": user_input
            })

            # Generate response
            try:
                response = self.generate_response(conversation_history)
                print(f"Helion: {response}\n")

                # Add assistant response to history
                conversation_history.append({
                    "role": "assistant",
                    "content": response
                })
            except Exception as e:
                print(f"Error generating response: {e}")
                conversation_history.pop()  # Remove the failed user message
def main():
    """Main function for CLI usage."""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-V1 Inference")
    parser.add_argument("--model", default="DeepXR/Helion-V1", help="Model name or path")
    parser.add_argument("--device", default="auto", help="Device to use (cuda/cpu/auto)")
    parser.add_argument("--interactive", action="store_true", help="Start interactive chat")
    parser.add_argument("--prompt", type=str, help="Single prompt to process")
    args = parser.parse_args()

    # Initialize model
    helion = HelionInference(model_name=args.model, device=args.device)

    if args.interactive:
        helion.chat()
    elif args.prompt:
        messages = [{"role": "user", "content": args.prompt}]
        response = helion.generate_response(messages)
        print(f"Response: {response}")
    else:
        print("Please specify --interactive or --prompt")


if __name__ == "__main__":
    main()
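# Programmatic usage sketch (assumes this file is importable as `inference`
# and the DeepXR/Helion-V1 checkpoint can be downloaded):
#
#   from inference import HelionInference
#   helion = HelionInference(device="cpu")
#   print(helion.generate_response([{"role": "user", "content": "Hello!"}],
#                                  max_new_tokens=64))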