| """ |
| test_smollm3_robust.py - Test the robust SmolLM3-3B model |
| |
| This script tests our newly trained model on various schemas to measure |
| the dramatic improvement in function calling capability. |
| """ |
|
|
| import torch |
| import json |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from peft import PeftModel |
|
|
def load_trained_model(adapter_path="./smollm3_robust", base_model_name="HuggingFaceTB/SmolLM3-3B"):
    """Load the SmolLM3-3B base model with the fine-tuned PEFT adapter applied.

    Args:
        adapter_path: Directory (or hub id) of the trained PEFT adapter.
        base_model_name: Hugging Face id of the base causal LM.

    Returns:
        ``(model, tokenizer, device)`` where ``device`` is ``"mps"`` when the
        MPS backend is available and ``"cpu"`` otherwise; the model has
        already been moved to that device.
    """
    print("🚀 Loading robust SmolLM3-3B model...")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # Make sure a pad token exists, falling back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base_model, adapter_path)

    # Use the MPS (Apple Metal) backend when present; otherwise stay on CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model = model.to(device)  # no-op when device is "cpu"

    print(f"✅ Model loaded on {device}")
    return model, tokenizer, device
|
|
def _parse_json_reply(text):
    """Parse a model reply as JSON, tolerating stray characters after an object.

    Returns:
        ``(obj, True)`` when ``text`` is valid JSON, or when it starts with a
        complete JSON object followed by extra characters (e.g. a doubled
        closing brace); ``(None, False)`` otherwise.
    """
    try:
        return json.loads(text), True
    except json.JSONDecodeError:
        pass
    # raw_decode parses only the leading JSON value and reports where it ended,
    # so trailing junk is ignored without ever truncating a well-formed reply.
    try:
        obj, _end = json.JSONDecoder().raw_decode(text)
    except json.JSONDecodeError:
        return None, False
    # Only accept the lenient path for objects: a prose reply such as
    # "500 dollars..." would otherwise be read as the JSON number 500.
    if isinstance(obj, dict):
        return obj, True
    return None, False


def test_function_call(model, tokenizer, device, schema, question):
    """Ask the model to call the function described by ``schema`` and grade the reply.

    Args:
        model: Causal LM exposing ``eval()`` and ``generate()``.
        tokenizer: Tokenizer used to encode the prompt and decode the completion.
        device: Device string the model lives on; prompt tensors are moved there.
        schema: Function schema dict; must contain a ``"name"`` key.
        question: User request that should be answered with a function call.

    Returns:
        ``(response, is_valid, json_response, score)``:
        ``response`` is the decoded completion with surrounding whitespace
        stripped; ``is_valid`` tells whether it parsed as JSON (trailing junk
        after a complete object is tolerated); ``json_response`` is the parsed
        value or ``None``; ``score`` is 0-4 -- one point each for valid JSON,
        a ``"name"`` key, an ``"arguments"`` key, and a name equal to
        ``schema["name"]``.
    """
    # NOTE(review): the <schema> section sits between the system and user
    # turns; presumably this mirrors the training format -- confirm before
    # changing it.
    prompt = f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>

<schema>
{json.dumps(schema, indent=2)}
</schema>

<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
"""

    inputs = tokenizer(prompt, return_tensors="pt")
    # Move to whatever device the model is on (a no-op for "cpu").
    inputs = {k: v.to(device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            # Low-temperature sampling rather than greedy decoding, so results
            # can differ between runs.
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion; decode only the new tokens.
    input_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()

    json_response, is_valid = _parse_json_reply(response)
    if not is_valid:
        return response, False, None, 0

    # Valid JSON that is not an object (list, string, number) earns only the
    # validity point instead of crashing on .get().
    if isinstance(json_response, dict):
        has_name = "name" in json_response
        has_args = "arguments" in json_response
        correct_name = json_response.get("name") == schema["name"]
    else:
        has_name = has_args = correct_name = False

    score = sum([is_valid, has_name, has_args, correct_name])
    return response, is_valid, json_response, score


# The script is named test_smollm3_robust.py, so pytest may collect this helper
# as a test and fail looking for fixtures called "model", "tokenizer", ...
test_function_call.__test__ = False
|
|
def main():
    """Run the function-calling evaluation and print a summary report.

    Loads the fine-tuned model, asks one question per schema in the built-in
    test cases, prints each raw reply with its 0-4 score, then reports the
    valid-JSON rate and overall score against the 80% target.
    """
    print("🧪 Testing Robust SmolLM3-3B Function Calling")
    print("=" * 55)

    model, tokenizer, device = load_trained_model()

    test_cases = [
        {
            "name": "Stock Price (Training)",
            "schema": {
                "name": "get_stock_price",
                "description": "Get current stock price for a ticker",
                "parameters": {
                    "type": "object",
                    "properties": {"ticker": {"type": "string"}},
                    "required": ["ticker"]
                }
            },
            "question": "What's Apple stock trading at?"
        },
        {
            "name": "Weather (Seen Pattern)",
            "schema": {
                "name": "get_weather",
                "description": "Get weather for a location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"]
                }
            },
            "question": "How's the weather in Tokyo?"
        },
        {
            "name": "NEW: Database Query",
            "schema": {
                "name": "execute_sql",
                "description": "Execute SQL query on database",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "database": {"type": "string"}
                    },
                    "required": ["query"]
                }
            },
            "question": "Find all users who registered this month"
        },
        {
            "name": "NEW: Complex Parameters",
            "schema": {
                "name": "book_flight",
                "description": "Book a flight ticket",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "from_city": {"type": "string"},
                        "to_city": {"type": "string"},
                        "departure_date": {"type": "string"},
                        "passengers": {"type": "integer"}
                    },
                    "required": ["from_city", "to_city", "departure_date"]
                }
            },
            "question": "Book a flight from New York to London for December 15th"
        },
        {
            "name": "NEW: Financial Transaction",
            "schema": {
                "name": "transfer_funds",
                "description": "Transfer money between accounts",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "amount": {"type": "number"},
                        "from_account": {"type": "string"},
                        "to_account": {"type": "string"},
                        "memo": {"type": "string"}
                    },
                    "required": ["amount", "from_account", "to_account"]
                }
            },
            "question": "Send $500 from checking to savings"
        }
    ]

    total_score = 0
    max_score = len(test_cases) * 4  # test_function_call awards at most 4 points
    valid_json_count = 0

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n📋 Test {i}: {test_case['name']}")
        print(f"❓ Question: {test_case['question']}")

        response, is_valid, json_obj, score = test_function_call(
            model, tokenizer, device, test_case["schema"], test_case["question"]
        )

        print(f"🤖 Raw response: {response}")

        if is_valid:
            print(f"✅ Valid JSON: {json_obj}")
            valid_json_count += 1
        else:
            print("❌ Invalid JSON")

        print(f"📊 Score: {score}/4")
        total_score += score
        print("-" * 50)

    valid_rate = valid_json_count / len(test_cases)
    valid_summary = f"{valid_json_count}/{len(test_cases)} ({valid_rate * 100:.1f}%)"

    print("\n🏆 FINAL RESULTS:")
    print(f"✅ Valid JSON responses: {valid_summary}")
    print(f"📊 Overall score: {total_score}/{max_score} ({total_score / max_score * 100:.1f}%)")
    print("🎯 Success criteria: ≥80% valid calls")

    if valid_rate >= 0.8:
        print("🎉 PASS - Excellent function calling capability!")
    elif valid_rate >= 0.6:
        print("🟡 GOOD - Strong improvement, approaching target")
    elif valid_json_count > 0:
        print("📈 PROGRESS - Significant improvement from baseline")
    else:
        # The baseline below is also 0%, so claiming improvement would be false.
        print("❌ NO IMPROVEMENT - Still 0% valid calls, same as baseline")

    # The previous-model result and loss figures are static literals, not
    # measured by this script.
    print("\n📊 IMPROVEMENT COMPARISON:")
    print("Previous SmolLM2-1.7B result: 0/3 (0%)")
    print(f"Current SmolLM3-3B result: {valid_summary}")
    print("📉 Training loss improvement: 2.38 → 1.49 (37% better)")


if __name__ == "__main__":
    main()