Trouter-Library committed on
Commit
01b0ba7
·
verified ·
1 Parent(s): d06b515

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +143 -0
inference.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helion-OSC Inference Script
3
+ DeepXR/Helion-OSC - Mathematical Coding Language Model
4
+ """
5
+
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ from typing import Optional, Dict, Any
9
+
10
+
11
class HelionOSCInference:
    """Inference wrapper for the Helion-OSC causal language model.

    The constructor loads the tokenizer and the model once; ``generate``
    runs a single prompt through the model, and ``code_generation`` /
    ``mathematical_reasoning`` are thin presets on top of ``generate``.
    """

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-OSC",
        device: Optional[str] = None,
        load_in_8bit: bool = False
    ):
        """
        Initialize the Helion-OSC model

        Args:
            model_name: HuggingFace model identifier
            device: Device to load model on ("cuda", "cuda:0", "cpu", ...).
                When omitted, "cuda" is used if available, otherwise "cpu".
            load_in_8bit: Whether to load model in 8-bit precision
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Treat indexed devices ("cuda:0", "cuda:1", ...) as GPU targets too;
        # an exact comparison with "cuda" would silently fall back to the
        # float32 / stay-on-CPU path for them.
        on_gpu = str(self.device).startswith("cuda")

        print(f"Loading Helion-OSC on {self.device}...")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        model_kwargs: Dict[str, Any] = {}
        if self.device == "cuda":
            # No specific GPU requested: let the library place the weights.
            model_kwargs["device_map"] = "auto"
        elif load_in_8bit:
            # Quantized models cannot be moved with .to() afterwards, so pin
            # the whole model to the requested device at load time instead.
            model_kwargs["device_map"] = {"": self.device}
        if load_in_8bit:
            model_kwargs["load_in_8bit"] = True

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if on_gpu else torch.float32,
            **model_kwargs
        )

        # Without a device_map the weights are still on the default device;
        # move them so they match where generate() sends the inputs.
        if "device_map" not in model_kwargs:
            self.model = self.model.to(self.device)

        self.model.eval()
        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        max_length: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.95,
        top_k: int = 50,
        num_return_sequences: int = 1,
        do_sample: bool = True,
        **kwargs
    ) -> str:
        """
        Generate code or text based on prompt

        Args:
            prompt: Input prompt
            max_length: Forwarded as ``max_length`` to ``model.generate``
                (in Transformers this bounds prompt plus generated tokens)
            temperature: Sampling temperature (only sent when ``do_sample``)
            top_p: Nucleus sampling parameter (only sent when ``do_sample``)
            top_k: Top-k sampling parameter (only sent when ``do_sample``)
            num_return_sequences: Number of sequences to generate. Only the
                first generated sequence is decoded and returned.
            do_sample: Whether to use sampling
            **kwargs: Extra arguments forwarded to ``model.generate``; they
                take precedence over the values built here.

        Returns:
            The first output sequence decoded with special tokens skipped
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        generation_kwargs: Dict[str, Any] = {
            "max_length": max_length,
            "num_return_sequences": num_return_sequences,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            # Sampling knobs have no effect on greedy decoding and only make
            # the library warn about unused flags, so send them only here.
            generation_kwargs.update(
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
        # Merge caller overrides last so that e.g. a custom pad_token_id
        # replaces the default instead of raising a duplicate-keyword error.
        generation_kwargs.update(kwargs)

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def code_generation(self, prompt: str, max_length: int = 1024) -> str:
        """Optimized for code generation tasks (sampling, T=0.7, top_p=0.95)"""
        return self.generate(
            prompt,
            max_length=max_length,
            temperature=0.7,
            top_p=0.95,
            do_sample=True
        )

    def mathematical_reasoning(self, prompt: str, max_length: int = 512) -> str:
        """Optimized for mathematical reasoning tasks (greedy decoding)"""
        # Sampling is off, so no temperature/top_p are passed: they would be
        # ignored by the decoder anyway.
        return self.generate(
            prompt,
            max_length=max_length,
            do_sample=False
        )
113
+
114
+
115
def main():
    """Load Helion-OSC and print its output for three sample prompts."""
    helion = HelionOSCInference()

    # Each demo is (section banner, prompt text, callable producing the output).
    demos = (
        (
            "\n=== Code Generation ===",
            "Write a Python function to calculate the factorial of a number using recursion:",
            helion.code_generation,
        ),
        (
            "\n=== Mathematical Reasoning ===",
            "Prove that the sum of first n natural numbers is n(n+1)/2:",
            helion.mathematical_reasoning,
        ),
        (
            "\n=== Algorithm Design ===",
            "Design an efficient algorithm to find the longest palindromic substring:",
            # The generic entry point is used here, with a longer length budget.
            lambda text: helion.generate(text, max_length=1024),
        ),
    )

    for banner, prompt_text, run in demos:
        print(banner)
        print(f"Prompt: {prompt_text}")
        output = run(prompt_text)
        print(f"Output:\n{output}\n")


if __name__ == "__main__":
    main()