# gemma-mini-code-agent / inference.py
# Author: Abhay557 -- commit 7e72f06 (verified): "Add standalone inference script" -- 2.68 kB
"""
========================================
INFERENCE SCRIPT FOR MINI CODING AGENT
Load your fine-tuned Gemma-3-1B-IT coding model and chat with it.
========================================
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Change this to your trained model path or Hub ID
MODEL_PATH = "./gemma-code-agent-merged"
# MODEL_PATH = "YOUR_USERNAME/gemma-3-1b-code-agent" # if pushed to Hub
def load_model(path: str):
    """Load the fine-tuned model and its tokenizer from ``path``.

    ``path`` is handed unchanged to both ``from_pretrained`` calls. The model
    is requested in bfloat16 with ``device_map="auto"``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading model from: {path}")

    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    model_options = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(path, **model_options)

    # chat() passes pad_token_id to generate(), so make sure one is defined:
    # reuse the end-of-sequence token when the tokenizer has no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
def chat(model, tokenizer, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7) -> str:
    """Generate a single-turn response for a coding prompt.

    The prompt is wrapped as one ``user`` message, rendered through the
    tokenizer's chat template, and passed to ``model.generate``. Only the
    newly generated tokens (everything after the prompt) are decoded.

    Args:
        model: Object exposing ``device`` and ``generate(**kwargs)``.
        tokenizer: Object exposing ``apply_chat_template``, ``decode`` and
            ``pad_token_id``.
        prompt: The user's message.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            (deterministic) decoding instead of sampling.

    Returns:
        The decoded response text with special tokens removed.
    """
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
    ).to(model.device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95)
    else:
        # Sampling with a non-positive temperature is rejected by the
        # generation code, so fall back to greedy decoding and leave out
        # the sampling-only arguments.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + continuation; slice off the prompt tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
def interactive_chat(model, tokenizer):
    """Run an interactive read-generate-print loop on the terminal.

    Each non-empty line typed by the user is sent to ``chat`` and the reply
    is printed. The loop ends when the user types ``exit``, ``quit`` or ``q``
    (case-insensitive), or when input ends (Ctrl-D, Ctrl-C, or exhausted
    piped stdin).

    Args:
        model: Model object forwarded to ``chat``.
        tokenizer: Tokenizer object forwarded to ``chat``.
    """
    print("\n" + "=" * 60)
    print(" MINI CODING AGENT - Interactive Chat")
    print(" Type 'exit' or 'quit' to stop")
    print("=" * 60 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or an interrupt at the prompt: leave cleanly
            # instead of crashing with a traceback.
            print("\nGoodbye!")
            break

        if not user_input:
            # Don't spend a generation call on a blank line.
            continue

        if user_input.lower() in ("exit", "quit", "q"):
            print("Goodbye!")
            break

        print("\nAgent: ", end="", flush=True)
        response = chat(model, tokenizer, user_input)
        print(response)
        print("-" * 60)
if __name__ == "__main__":
    model, tokenizer = load_model(MODEL_PATH)

    # Run one fixed prompt first, before handing control to the user.
    print("\nQuick test:")
    test = "Write a Python function to reverse a string without using built-in reverse methods."
    print(f"You: {test}")
    answer = chat(model, tokenizer, test)
    print(f"\nAgent: {answer}")

    # Then enter the interactive loop until the user exits.
    interactive_chat(model, tokenizer)