# pynb-73m-base / inference.py
# Uploaded by AutomatedScientist via huggingface_hub (revision 9ab70a9, verified)
"""inference.py - Code generation model wrapper for smolagents"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class CodeModel:
    """Thin wrapper around a causal LM for code completion and chat generation."""

    def __init__(self, model_id: str, device: str = None):
        """Load tokenizer and model.

        Args:
            model_id: Hugging Face hub id or local checkpoint directory.
            device: Target device ("cuda"/"cpu"); auto-detected when None.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, fix_mistral_regex=True)
        # Many causal-LM tokenizers ship without a pad token; generate() then
        # receives pad_token_id=None. Fall back to EOS, the standard remedy.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # bf16 only on CUDA; CPU stays fp32 for compatibility/accuracy.
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device, dtype=dtype)
        self.model.eval()

    def _sample(self, inputs, max_new_tokens: int, temperature: float) -> str:
        """Run nucleus sampling and decode only the newly generated tokens.

        Shared by generate() and chat() so both use identical sampling
        parameters (the original chat path silently omitted pad/eos ids).
        """
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )
        # Slice off the prompt so only the continuation is returned.
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        # NOTE: special tokens are kept on purpose (base-model output may be
        # post-processed by the caller) — confirm before changing.
        return self.tokenizer.decode(new_tokens, skip_special_tokens=False)

    def generate(self, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7) -> str:
        """Complete a raw text prompt and return only the generated suffix."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        return self._sample(inputs, max_new_tokens, temperature)

    def chat(self, messages: list[dict], max_new_tokens: int = 256) -> str:
        """Generate response using chat template."""
        text = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        return self._sample(inputs, max_new_tokens, temperature=0.7)
if __name__ == "__main__":
    import os

    # Prefer a local checkpoint directory when one exists; otherwise pull
    # the published weights from the Hub.
    repo_or_path = (
        "checkpoint"
        if os.path.exists("checkpoint")
        else "AutomatedScientist/pynb-73m-base"
    )
    code_model = CodeModel(repo_or_path)

    # Demo 1: plain prompt completion.
    completion = code_model.generate("Write a Python function to calculate factorial")
    print("Generated code:")
    print(completion)

    # Demo 2: chat-template generation.
    chat_messages = [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "Write a function to reverse a string"},
    ]
    reply = code_model.chat(chat_messages)
    print("\nChat response:")
    print(reply)