Trouter-Library committed
Commit 05fe834 · verified · 1 Parent(s): 8a776ef

Create inference_v15.py

Files changed (1)
inference_v15.py +194 -0
inference_v15.py ADDED
@@ -0,0 +1,194 @@
+"""
+Helion-V1.5 Inference Script
+Simple interface for using the model
+"""
+
+import torch
+import logging
+from typing import List, Dict, Optional
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class HelionV15:
+    """Easy-to-use interface for Helion-V1.5."""
+
+    def __init__(
+        self,
+        model_name: str = "DeepXR/Helion-V1.5",
+        device: str = "auto",
+        load_in_4bit: bool = False
+    ):
+        """
+        Initialize the Helion-V1.5 model.
+
+        Args:
+            model_name: Model name or path
+            device: Device to load the model on
+            load_in_4bit: Use 4-bit quantization
+        """
+        from transformers import AutoTokenizer, AutoModelForCausalLM
+
+        logger.info(f"Loading Helion-V1.5: {model_name}")
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        load_kwargs = {
+            "device_map": device,
+            "torch_dtype": torch.bfloat16,
+            "trust_remote_code": True
+        }
+
+        if load_in_4bit:
+            from transformers import BitsAndBytesConfig
+            load_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.bfloat16
+            )
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            **load_kwargs
+        )
+
+        self.model.eval()
+        logger.info("Model loaded successfully")
+
+    def chat(
+        self,
+        messages: List[Dict[str, str]],
+        max_new_tokens: int = 512,
+        temperature: float = 0.7,
+        top_p: float = 0.9,
+        do_sample: bool = True
+    ) -> str:
+        """
+        Generate a response from a list of messages.
+
+        Args:
+            messages: List of message dicts with 'role' and 'content'
+            max_new_tokens: Maximum tokens to generate
+            temperature: Sampling temperature
+            top_p: Nucleus sampling parameter
+            do_sample: Whether to use sampling
+
+        Returns:
+            Generated response text
+        """
+        # Apply chat template
+        input_ids = self.tokenizer.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            return_tensors="pt"
+        ).to(self.model.device)
+
+        # Generate
+        # Fall back to the EOS token if the tokenizer defines no pad token
+        pad_id = self.tokenizer.pad_token_id
+        if pad_id is None:
+            pad_id = self.tokenizer.eos_token_id
+
+        with torch.no_grad():
+            output = self.model.generate(
+                input_ids,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                do_sample=do_sample,
+                pad_token_id=pad_id,
+                eos_token_id=self.tokenizer.eos_token_id
+            )
+
+        # Decode only the newly generated tokens
+        response = self.tokenizer.decode(
+            output[0][input_ids.shape[1]:],
+            skip_special_tokens=True
+        )
+
+        return response.strip()
+
+    def generate(
+        self,
+        prompt: str,
+        max_new_tokens: int = 512,
+        **kwargs
+    ) -> str:
+        """
+        Generate text from a simple prompt.
+
+        Args:
+            prompt: Input text
+            max_new_tokens: Maximum tokens to generate
+            **kwargs: Additional generation parameters
+
+        Returns:
+            Generated text
+        """
+        messages = [{"role": "user", "content": prompt}]
+        return self.chat(messages, max_new_tokens=max_new_tokens, **kwargs)
+
+    def interactive(self):
+        """Start an interactive chat session."""
+        print("\n" + "=" * 60)
+        print("Helion-V1.5 Interactive Chat")
+        print("Type 'quit' or 'exit' to end")
+        print("=" * 60 + "\n")
+
+        conversation = []
+
+        while True:
+            user_input = input("You: ").strip()
+
+            if user_input.lower() in ['quit', 'exit']:
+                print("Goodbye!")
+                break
+
+            if not user_input:
+                continue
+
+            conversation.append({"role": "user", "content": user_input})
+
+            try:
+                response = self.chat(conversation)
+                print(f"Helion: {response}\n")
+
+                conversation.append({"role": "assistant", "content": response})
+
+            except Exception as e:
+                print(f"Error: {e}")
+                conversation.pop()  # Remove the failed message
+
+
+def main():
+    """Main CLI interface."""
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Helion-V1.5 Inference")
+    parser.add_argument("--model", default="DeepXR/Helion-V1.5")
+    parser.add_argument("--device", default="auto")
+    parser.add_argument("--4bit", dest="load_4bit", action="store_true", help="Use 4-bit quantization")
+    parser.add_argument("--interactive", action="store_true", help="Interactive chat")
+    parser.add_argument("--prompt", type=str, help="Single prompt")
+    parser.add_argument("--max-tokens", type=int, default=512)
+    parser.add_argument("--temperature", type=float, default=0.7)
+
+    args = parser.parse_args()
+
+    # Initialize model
+    helion = HelionV15(
+        model_name=args.model,
+        device=args.device,
+        load_in_4bit=args.load_4bit
+    )
+
+    if args.interactive:
+        helion.interactive()
+    elif args.prompt:
+        response = helion.generate(
+            args.prompt,
+            max_new_tokens=args.max_tokens,
+            temperature=args.temperature
+        )
+        print(f"\nResponse:\n{response}")
+    else:
+        print("Use --interactive or --prompt")
+
+
+if __name__ == "__main__":
+    main()
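A minimal usage sketch for the file above, assuming it is saved as inference_v15.py with torch and transformers installed (plus bitsandbytes for the 4-bit path) and the DeepXR/Helion-V1.5 checkpoint is reachable:

from inference_v15 import HelionV15

# One-off generation through the Python API
helion = HelionV15()
print(helion.generate("Summarize nucleus sampling in one sentence."))

# Multi-turn chat with explicit roles
messages = [{"role": "user", "content": "Hi, who are you?"}]
print(helion.chat(messages, max_new_tokens=128))

The same flows are available from the command line:

python inference_v15.py --prompt "Hello" --max-tokens 256
python inference_v15.py --interactive --4bit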