File size: 2,675 Bytes
7e72f06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
========================================
INFERENCE SCRIPT FOR MINI CODING AGENT
Load your fine-tuned Gemma-3-1B-IT coding model and chat with it.
========================================
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model location handed to load_model() when this file runs as a script.
# Change this to your trained model path (local directory) or Hub ID.
MODEL_PATH = "./gemma-code-agent-merged"
# Alternative if the model was pushed to the Hub:
# MODEL_PATH = "YOUR_USERNAME/gemma-3-1b-code-agent"  # if pushed to Hub

def load_model(path: str):
    """Load the fine-tuned coding agent model together with its tokenizer.

    Args:
        path: Model location passed unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading model from: {path}")

    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    # Keyword options for the model load, collected in one place.
    model_options = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(path, **model_options)

    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def chat(model, tokenizer, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7) -> str:
    """Generate a response for a single-turn coding prompt.

    Args:
        model: Object exposing ``device`` and ``generate(**kwargs)``.
        tokenizer: Object exposing ``apply_chat_template``, ``decode`` and
            ``pad_token_id``.
        prompt: The user's message; sent as the only turn of the conversation.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A positive value enables sampling
            (with ``top_p=0.95``); ``None`` or a value <= 0 selects greedy
            decoding instead of being forwarded to ``generate``.

    Returns:
        The decoded text of the newly generated tokens only (the prompt tokens
        are sliced off before decoding).
    """
    messages = [{"role": "user", "content": prompt}]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
    ).to(model.device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # Sampling requires a strictly positive temperature; a non-positive value
    # is treated as a request for deterministic (greedy) output.
    if temperature is not None and temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95)
    else:
        # Sampling-only knobs are omitted so they are not passed alongside
        # do_sample=False.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Everything before prompt_length is the echoed prompt; keep only the reply.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )
    return response


def interactive_chat(model, tokenizer):
    """Run an interactive chat loop on the console.

    Reads prompts from standard input, sends each one to ``chat`` and prints
    the reply. The loop ends when the user types ``exit``, ``quit`` or ``q``
    (case-insensitive), or when input is closed or interrupted
    (Ctrl-D / Ctrl-C at the prompt).

    Args:
        model: Model object forwarded to ``chat``.
        tokenizer: Tokenizer object forwarded to ``chat``.
    """
    print("\n" + "=" * 60)
    print("  MINI CODING AGENT - Interactive Chat")
    print("  Type 'exit' or 'quit' to stop")
    print("=" * 60 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt: leave cleanly instead of
            # dumping a traceback. The newline ends the unfinished prompt line.
            print("\nGoodbye!")
            break

        # A blank line is not a prompt; ask again rather than querying the model.
        if not user_input:
            continue

        if user_input.lower() in ("exit", "quit", "q"):
            print("Goodbye!")
            break

        print("\nAgent: ", end="", flush=True)
        response = chat(model, tokenizer, user_input)
        print(response)
        print("-" * 60)


if __name__ == "__main__":
    model, tokenizer = load_model(MODEL_PATH)

    # Run one fixed prompt first, before the interactive session starts.
    print("\nQuick test:")
    test = "Write a Python function to reverse a string without using built-in reverse methods."
    print(f"You: {test}")
    reply = chat(model, tokenizer, test)
    print(f"\nAgent: {reply}")

    # Hand control to the user.
    interactive_chat(model, tokenizer)