| """ |
| ======================================== |
| INFERENCE SCRIPT FOR MINI CODING AGENT |
| Load your fine-tuned Gemma-3-1B-IT coding model and chat with it. |
| ======================================== |
| """ |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| |
# Model location handed to load_model() when this file is run as a script (a path relative to the working directory).
| MODEL_PATH = "./gemma-code-agent-merged" |
| |
|
|
def load_model(path: str):
    """Load the tokenizer and causal language model stored at ``path``.

    Args:
        path: Location forwarded unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer is guaranteed to have a
        pad token: if it ships without one, its EOS token is reused.
    """
    print(f"Loading model from: {path}")

    # NOTE: trust_remote_code=True permits code bundled with the checkpoint to
    # run during loading, so only point this at checkpoints you trust.
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # generate() needs a pad token id; fall back to EOS when none is set.
        tokenizer.pad_token = tokenizer.eos_token

    model_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)

    return model, tokenizer
|
|
def chat(model, tokenizer, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7) -> str:
    """Generate a single-turn reply to ``prompt``.

    The prompt is wrapped as one ``user`` message, rendered through the
    tokenizer's chat template (with the generation prompt appended), moved to
    ``model.device`` and passed to ``model.generate``.

    Args:
        model: Object exposing ``device`` and ``generate(**kwargs)``.
        tokenizer: Object exposing ``apply_chat_template``, ``decode`` and
            ``pad_token_id``.
        prompt: The user's message.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A positive value enables nucleus
            sampling (``top_p=0.95``). ``None`` or a value ``<= 0`` selects
            greedy decoding instead of forwarding an invalid temperature to
            ``generate``.

    Returns:
        The decoded text of the newly generated tokens only (the prompt
        tokens are sliced off), with special tokens removed.
    """
    messages = [{"role": "user", "content": prompt}]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
    ).to(model.device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95)
    else:
        # Sampling with a non-positive temperature is rejected by generate();
        # treat it as a request for deterministic (greedy) output and do not
        # pass the sampling-only arguments at all.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
|
|
|
|
def interactive_chat(model, tokenizer):
    """Run a read-generate-print loop on the terminal until the user quits.

    Each non-empty line typed at the ``You:`` prompt is sent to ``chat`` and
    the reply is printed. The loop ends when the user types ``exit``,
    ``quit`` or ``q`` (case-insensitive), or when input is closed or
    interrupted (Ctrl-D / Ctrl-C at the prompt).

    Args:
        model: Model object forwarded to ``chat``.
        tokenizer: Tokenizer object forwarded to ``chat``.
    """
    print("\n" + "=" * 60)
    print(" MINI CODING AGENT - Interactive Chat")
    print(" Type 'exit' or 'quit' to stop")
    print("=" * 60 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt: leave cleanly rather than
            # dumping a traceback on the user.
            print("\nGoodbye!")
            break

        if not user_input:
            # Nothing was typed; don't spend a generation on an empty prompt.
            continue

        if user_input.lower() in ("exit", "quit", "q"):
            print("Goodbye!")
            break

        # Print the label first (flushed) so it is visible while generating.
        print("\nAgent: ", end="", flush=True)
        response = chat(model, tokenizer, user_input)
        print(response)
        print("-" * 60)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load the model once, run one canned prompt as a
    # smoke test, then hand control to the interactive loop.
    model, tokenizer = load_model(MODEL_PATH)

    sample_prompt = "Write a Python function to reverse a string without using built-in reverse methods."
    print("\nQuick test:")
    print(f"You: {sample_prompt}")
    reply = chat(model, tokenizer, sample_prompt)
    print(f"\nAgent: {reply}")

    interactive_chat(model, tokenizer)
|
|