"""inference.py - Code generation model wrapper for smolagents"""
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class CodeModel:
    """Wrap a Hugging Face causal LM and its tokenizer for text/code generation.

    Both ``generate`` (raw prompt) and ``chat`` (chat-template messages) use
    nucleus sampling (``top_p=0.9``) with ``repetition_penalty=1.2`` and return
    only the newly generated text, never the prompt.
    """

    def __init__(self, model_id: str, device: Optional[str] = None):
        """Load the tokenizer and model and move the model to ``device``.

        Args:
            model_id: Hugging Face Hub repo id or local checkpoint directory.
            device: Torch device string (e.g. ``"cpu"``, ``"cuda"``,
                ``"cuda:1"``). Defaults to ``"cuda"`` when CUDA is available,
                otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): fix_mistral_regex is forwarded to the tokenizer loader;
        # presumably it corrects the pre-tokenizer regex of Mistral-derived
        # tokenizers — confirm against the installed transformers version.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, fix_mistral_regex=True)
        # bfloat16 on any CUDA device, including indexed ones such as "cuda:1"
        # that a plain `== "cuda"` comparison would miss; float32 elsewhere.
        on_cuda = torch.device(self.device).type == "cuda"
        dtype = torch.bfloat16 if on_cuda else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device, dtype=dtype)
        self.model.eval()

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        *,
        skip_special_tokens: bool = False,
    ) -> str:
        """Continue a raw text prompt.

        Args:
            prompt: Text fed to the model; the tokenizer adds its usual special
                tokens (e.g. BOS) around it.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature. Values ``<= 0`` select greedy
                decoding instead of failing inside ``model.generate``.
            skip_special_tokens: Drop special tokens (e.g. EOS) from the
                decoded output. Defaults to ``False`` (they are kept).

        Returns:
            The decoded continuation, excluding the prompt.
        """
        return self._generate_from_text(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            add_special_tokens=True,
            # Raw completion stops on the tokenizer's own EOS token.
            eos_token_id=self.tokenizer.eos_token_id,
            skip_special_tokens=skip_special_tokens,
        )

    def chat(
        self,
        messages: list[dict],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        *,
        skip_special_tokens: bool = False,
    ) -> str:
        """Generate response using chat template.

        Args:
            messages: Conversation in the form accepted by
                ``tokenizer.apply_chat_template`` (e.g. dicts with ``"role"``
                and ``"content"`` keys).
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature. Values ``<= 0`` select greedy
                decoding.
            skip_special_tokens: Drop special tokens from the decoded output.
                Defaults to ``False`` (they are kept).

        Returns:
            The decoded assistant reply, excluding the rendered prompt.

        Raises:
            Presumably whatever ``apply_chat_template`` raises when the
            tokenizer defines no chat template (e.g. a base model) — confirm
            for the installed transformers version.
        """
        text = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        return self._generate_from_text(
            text,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            # The rendered template already contains the special tokens it
            # needs; letting the tokenizer add them again duplicates BOS.
            add_special_tokens=False,
            # Leave EOS to the model's generation config, which for chat models
            # may list several stop tokens (e.g. an end-of-turn marker).
            eos_token_id=None,
            skip_special_tokens=skip_special_tokens,
        )

    def _generate_from_text(
        self,
        text: str,
        *,
        max_new_tokens: int,
        temperature: float,
        add_special_tokens: bool,
        eos_token_id: Optional[int],
        skip_special_tokens: bool,
    ) -> str:
        """Tokenize ``text``, run ``model.generate``, and decode only the new tokens."""
        inputs = self.tokenizer(
            text, return_tensors="pt", add_special_tokens=add_special_tokens
        ).to(self.device)

        gen_kwargs: dict = {"max_new_tokens": max_new_tokens, "repetition_penalty": 1.2}
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
        else:
            # Sampling requires a strictly positive temperature; treat
            # non-positive values as a request for deterministic greedy output.
            gen_kwargs["do_sample"] = False

        # Causal-LM tokenizers often have no pad token; use EOS explicitly
        # rather than passing None and relying on transformers' own fallback.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id
        if pad_token_id is not None:
            gen_kwargs["pad_token_id"] = pad_token_id
        # Only override EOS with a real id: an explicit None may replace the
        # model's configured stop token(s), letting generation run to the limit.
        if eos_token_id is not None:
            gen_kwargs["eos_token_id"] = eos_token_id

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # Slice off the prompt so only freshly generated tokens are decoded.
        prompt_len = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(
            outputs[0, prompt_len:], skip_special_tokens=skip_special_tokens
        )


if __name__ == "__main__":
    import os

    # Use local checkpoint if available, otherwise HuggingFace.
    # from_pretrained loads a local checkpoint from a directory, so check for one.
    model_id = "checkpoint" if os.path.isdir("checkpoint") else "AutomatedScientist/pynb-73m-base"
    model = CodeModel(model_id)

    # Example: Generate code
    result = model.generate("Write a Python function to calculate factorial")
    print("Generated code:")
    print(result)

    # Example: Chat
    messages = [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "Write a function to reverse a string"},
    ]
    response = model.chat(messages)
    print("\nChat response:")
    print(response)