Update inference.py
Browse files
- inference.py +11 -39
inference.py
CHANGED
|
@@ -1,42 +1,14 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
Load Sentinel model and tokenizer.
|
| 6 |
-
"""
|
| 7 |
-
print(f"Loading {model_name}...")
|
| 8 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 9 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 10 |
-
model_name,
|
| 11 |
-
device_map="auto", # Uses GPU if available
|
| 12 |
-
trust_remote_code=True
|
| 13 |
-
)
|
| 14 |
-
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
| 15 |
-
return generator
|
| 16 |
-
|
| 17 |
-
def code_with_sentinel(prompt, generator, max_new_tokens=200):
|
| 18 |
-
"""
|
| 19 |
-
Generate code from a natural language prompt.
|
| 20 |
-
"""
|
| 21 |
-
print(f"\nPrompt: {prompt}\n")
|
| 22 |
-
output = generator(
|
| 23 |
-
prompt,
|
| 24 |
-
max_new_tokens=max_new_tokens,
|
| 25 |
-
do_sample=True,
|
| 26 |
-
top_p=0.9,
|
| 27 |
-
temperature=0.7,
|
| 28 |
-
eos_token_id=generator.tokenizer.eos_token_id
|
| 29 |
-
)
|
| 30 |
-
result = output[0]["generated_text"]
|
| 31 |
-
# Return only new code, not the full prompt
|
| 32 |
-
return result[len(prompt):].strip()
|
| 33 |
|
| 34 |
if __name__ == "__main__":
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
code =
|
| 40 |
-
|
| 41 |
-
print("Generated Code:\n")
|
| 42 |
-
print(code)
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from model_loader import load_model
|
| 3 |
+
from generator import generate_code
|
| 4 |
+
from utils import pretty_print_code
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
if __name__ == "__main__":
|
| 7 |
+
parser = argparse.ArgumentParser(description="Run Sentinel for code generation")
|
| 8 |
+
parser.add_argument("prompt", type=str, help="The coding prompt to generate code for")
|
| 9 |
+
parser.add_argument("--max_new_tokens", type=int, default=200, help="Maximum tokens to generate")
|
| 10 |
+
args = parser.parse_args()
|
| 11 |
|
| 12 |
+
generator = load_model("your-username/sentinel")
|
| 13 |
+
code = generate_code(args.prompt, generator, max_new_tokens=args.max_new_tokens)
|
| 14 |
+
pretty_print_code(code)
|
|
|
|
|
|