| """ |
| test_model.py - Test our trained dynamic function-calling agent |
| |
| This script loads the trained LoRA adapter and tests it on various schemas |
| to demonstrate zero-shot function calling capability. |
| """ |
|
|
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from peft import PeftModel |
| import json |
|
|
def load_trained_model(
    base_model_name: str = "HuggingFaceTB/SmolLM2-1.7B-Instruct",
    adapter_path: str = "./smollm_tool_adapter/checkpoint-6",
):
    """Load the base model and attach the trained LoRA adapter.

    Args:
        base_model_name: Hugging Face model id (or local path) of the base
            causal LM. The tokenizer is loaded from the same location.
        adapter_path: Directory holding the trained PEFT adapter checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple: the adapter-wrapped model and the
        tokenizer, whose pad token falls back to the EOS token when unset.
    """
    print("🔄 Loading trained model...")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        # Reuse EOS for padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    # Half precision and automatic device placement only when a GPU is
    # available; otherwise load in float32 with default placement.
    use_cuda = torch.cuda.is_available()
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
        # NOTE(review): this permits running model code fetched from the Hub;
        # confirm the base model actually requires it.
        trust_remote_code=True,
    )

    # Wrap the base model with the trained adapter weights.
    model = PeftModel.from_pretrained(base_model, adapter_path)

    print("✅ Model loaded successfully!")
    return model, tokenizer
|
|
def test_function_call(model, tokenizer, schema, question):
    """Run one schema/question pair through the model and parse its reply.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        schema: JSON-serializable function schema embedded in the prompt.
        question: User request the model should answer with a function call.

    Returns:
        ``(response, is_valid_json, json_response)`` where ``response`` is the
        stripped generated text, ``is_valid_json`` tells whether it parsed as
        JSON, and ``json_response`` is the parsed value (``None`` on failure).
    """
    # Built from explicit pieces so the prompt text never depends on how this
    # function happens to be indented in the source file.
    prompt = (
        "<|im_start|>system\n"
        "You are a helpful assistant that calls functions by responding with "
        "valid JSON when given a schema. Always respond with JSON function "
        "calls only, never prose.<|im_end|>\n"
        "\n"
        "<schema>\n"
        f"{json.dumps(schema, indent=2)}\n"
        "</schema>\n"
        "\n"
        "<|im_start|>user\n"
        f"{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    # The input tensors must live on the same device as the model, which may
    # have been placed on a GPU at load time.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # The generated sequence starts with the prompt tokens; keep only the
    # newly generated continuation.
    response = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    try:
        json_response = json.loads(response)
    except ValueError:
        # json.JSONDecodeError subclasses ValueError; anything else (for
        # example KeyboardInterrupt) is deliberately left to propagate.
        return response, False, None

    return response, True, json_response
|
|
def main():
    """Run the sample schema/question cases and print a pass/fail summary.

    Loads the model once, sends every test case through
    ``test_function_call``, reports whether each reply parsed as JSON, and
    prints ``PASS`` when at least 80% of the replies were valid JSON.
    """
    print("🧪 Testing Dynamic Function-Calling Agent")
    print("=" * 50)

    model, tokenizer = load_trained_model()

    # One case is labelled as a trained schema; the other two are labelled as
    # new schemas.
    test_cases = [
        {
            "name": "Trained Schema: Stock Price",
            "schema": {
                "name": "get_stock_price",
                "description": "Return the latest price for a given ticker symbol.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "ticker": {"type": "string"},
                    },
                    "required": ["ticker"],
                },
            },
            "question": "What's Microsoft trading at?",
        },
        {
            "name": "NEW Schema: Database Query",
            "schema": {
                "name": "query_database",
                "description": "Execute a SQL query on the database.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "timeout": {"type": "number"},
                    },
                    "required": ["query"],
                },
            },
            "question": "Find all users who signed up last week",
        },
        {
            "name": "NEW Schema: File Operations",
            "schema": {
                "name": "create_file",
                "description": "Create a new file with content.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filename": {"type": "string"},
                        "content": {"type": "string"},
                        "overwrite": {"type": "boolean"},
                    },
                    "required": ["filename", "content"],
                },
            },
            "question": "Create a file called report.txt with the content 'Meeting notes'",
        },
    ]

    valid_count = 0
    total_count = len(test_cases)

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n📋 Test {i}: {test_case['name']}")
        print(f"❓ Question: {test_case['question']}")

        response, is_valid, json_obj = test_function_call(
            model, tokenizer, test_case["schema"], test_case["question"]
        )

        print(f"🤖 Model response: {response}")

        if is_valid:
            print(f"✅ Valid JSON: {json_obj}")
            valid_count += 1
        else:
            print("❌ Invalid JSON")

        print("-" * 40)

    # Computed once and guarded so an empty case list cannot divide by zero.
    success_rate = valid_count / total_count if total_count else 0.0

    print("\n📊 Results Summary:")
    print(f"✅ Valid JSON responses: {valid_count}/{total_count} ({success_rate:.1%})")
    print("🎯 Success criteria: ≥80% valid calls")
    print(f"🏆 Result: {'PASS' if success_rate >= 0.8 else 'NEEDS IMPROVEMENT'}")
|
|
# Run the test suite only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()