"""inference.py - Code generation model wrapper for smolagents"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class CodeModel:
    """Thin wrapper around a Hugging Face causal LM for code generation."""

    def __init__(self, model_id: str, device: str | None = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Causal LM tokenizers often ship without a pad token; fall back to EOS
        # so generate() has a valid pad_token_id.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # bfloat16 on GPU for speed and memory, float32 on CPU for compatibility.
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device, dtype=dtype)
        self.model.eval()
    def generate(self, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7) -> str:
        """Complete a raw text prompt and return only the newly generated text."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Slice off the prompt tokens so only the completion is decoded.
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

    def chat(self, messages: list[dict], max_new_tokens: int = 256) -> str:
        """Generate a response for a list of chat messages using the tokenizer's chat template."""
        text = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Return only the newly generated assistant tokens.
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)


if __name__ == "__main__":
    import os

    # Prefer a locally fine-tuned checkpoint if one exists, otherwise pull the base model from the Hub.
    model_id = "checkpoint" if os.path.exists("checkpoint") else "AutomatedScientist/pynb-73m-base"
    model = CodeModel(model_id)

    # Plain-prompt completion.
    result = model.generate("Write a Python function to calculate factorial")
    print("Generated code:")
    print(result)

    # Chat-style generation via the tokenizer's chat template.
    messages = [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "Write a function to reverse a string"},
    ]
    response = model.chat(messages)
    print("\nChat response:")
    print(response)
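
# --- smolagents integration note (illustrative sketch, not executed by this module) ---
# The module docstring targets smolagents. smolagents also ships its own Transformers
# wrapper; the class names and signatures below are assumptions -- verify them against
# the smolagents version you have installed before relying on them:
#
#     from smolagents import CodeAgent, TransformersModel
#
#     agent = CodeAgent(tools=[], model=TransformersModel(model_id="AutomatedScientist/pynb-73m-base"))
#     agent.run("Write a Python function to calculate factorial")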