| """ |
| test_constrained_model_spaces.py - SPACES-OPTIMIZED Constrained Generation |
| |
| Ultra-aggressive optimization for Hugging Face Spaces environment |
| """ |
|
|
import json
import threading
import time
from typing import Dict, Optional, Tuple

import jsonschema
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
class TimeoutException(Exception):
    """Exception type for timeout conditions.

    NOTE(review): nothing in the visible part of this module raises or
    catches it; generation timeouts are reported through return values.
    """
|
|
def load_trained_model():
    """Load the SmolLM3-3B base model and, if possible, a fine-tuned LoRA adapter.

    The tokenizer and base model for ``HuggingFaceTB/SmolLM3-3B`` are loaded
    first (float16, ``device_map="auto"``). Each candidate adapter location is
    then tried in order; the first one that loads *and* merges successfully is
    folded into the model. If none works, the plain base model is returned.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: whatever the tokenizer or base-model loading raised; it is
            printed and then re-raised unchanged. Adapter failures are not
            fatal and are only printed.
    """
    print("🔄 Loading SmolLM3-3B Function-Calling Agent...")

    base_model_name = "HuggingFaceTB/SmolLM3-3B"

    try:
        print("🔄 Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # generate() needs a pad token; fall back to EOS when none is defined.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        print("🔄 Loading base model...")

        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True
        )

        # Candidate adapter locations, in priority order: Hub repo first,
        # then local directories.
        adapter_paths = [
            "jlov7/SmolLM3-Function-Calling-LoRA",
            "./model_files",
            "./smollm3_robust",
            "./hub_upload",
        ]

        model_loaded = False
        for i, adapter_path in enumerate(adapter_paths):
            from_hub = i == 0
            try:
                if from_hub:
                    print("🔄 Loading fine-tuned adapter from Hugging Face Hub...")
                else:
                    print(f"🔄 Trying local path: {adapter_path}")

                # Imported inside the try so that a missing `peft` package is
                # reported like any other adapter failure and the base model
                # is still returned.
                from peft import PeftModel

                # Keep the wrapper in a temporary: `model` is only rebound once
                # the merge has succeeded, so a failure in merge_and_unload()
                # cannot leave `model` pointing at an un-merged PEFT wrapper
                # that later attempts would wrap again.
                peft_model = PeftModel.from_pretrained(model, adapter_path)
                model = peft_model.merge_and_unload()

                if from_hub:
                    print("✅ Fine-tuned model loaded successfully from Hub!")
                else:
                    print(f"✅ Fine-tuned model loaded successfully from {adapter_path}!")
                model_loaded = True
                break

            except Exception as e:
                # Best effort: report and move on to the next candidate.
                if from_hub:
                    print(f"⚠️ Hub adapter not found: {e}")
                else:
                    print(f"⚠️ Path {adapter_path} failed: {e}")
                continue

        if not model_loaded:
            print("🔧 Using base model with optimized prompting")

        print("✅ Model loaded successfully")
        return model, tokenizer

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise
|
|
def _run_generation_with_deadline(model, tokenizer, inputs, temperature, timeout_s=4):
    """Run ``model.generate`` on a worker thread, waiting at most *timeout_s* seconds.

    Returns ``(outputs, exc, timed_out)``: the generate() result (or ``None``),
    the exception raised by the worker (or ``None``), and whether the worker
    was still running when the wait expired.
    """
    outcome = {"outputs": None, "exc": None}

    def worker():
        try:
            # no_grad is entered inside the worker so it applies to the thread
            # that actually runs generate().
            with torch.no_grad():
                outcome["outputs"] = model.generate(
                    **inputs,
                    max_new_tokens=25,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    num_return_sequences=1,
                    use_cache=True,
                    repetition_penalty=1.2,
                )
        except Exception as exc:
            # Keep the exception object itself; an empty message must still
            # count as a failure.
            outcome["exc"] = exc

    thread = threading.Thread(target=worker)
    thread.daemon = True
    thread.start()
    thread.join(timeout=timeout_s)

    if thread.is_alive():
        # The thread cannot be cancelled; it keeps running in the background.
        return None, None, True
    return outcome["outputs"], outcome["exc"], False


def _extract_json_text(response):
    """Return the text of the first JSON value starting at the first ``{``.

    Uses the real JSON decoder, so braces inside string literals do not end
    the object early. If there is no ``{`` the whole response is returned; if
    decoding from the first ``{`` fails, the text from that brace onward is
    returned so the caller's ``json.loads`` reports the actual error.
    """
    start = response.find("{")
    if start == -1:
        return response
    try:
        _, end = json.JSONDecoder().raw_decode(response, start)
    except json.JSONDecodeError:
        return response[start:]
    return response[start:end]


def constrained_json_generate(
    model,
    tokenizer,
    prompt: str,
    schema: Dict,
    max_attempts: int = 2,
) -> Tuple[str, bool, Optional[str]]:
    """Generate a short completion and validate it as JSON against *schema*.

    Each attempt samples up to 25 new tokens with a temperature that starts at
    0.1 and rises by 0.2 per attempt, extracts the first JSON object from the
    decoded text, and validates it with ``jsonschema``.

    Args:
        model: object exposing ``parameters()`` and ``generate(...)``.
        tokenizer: callable tokenizer exposing ``decode`` and ``eos_token_id``.
        prompt: full prompt text passed to the tokenizer as-is.
        schema: JSON schema the parsed output must satisfy.
        max_attempts: number of sampling attempts for invalid output.

    Returns:
        ``(json_text, success, error_message)``. On success ``error_message``
        is ``None``. On a validation failure ``json_text`` is the text that
        failed to parse or validate. On timeout or generation error it is
        ``""``.

    Notes:
        A timeout, a generation exception, or an empty result ends the call
        immediately without retrying; only invalid JSON / schema mismatches
        and other unexpected exceptions consume further attempts.
    """
    device = next(model.parameters()).device

    for attempt in range(max_attempts):
        is_last_attempt = attempt == max_attempts - 1
        try:
            # Slightly more randomness on each retry.
            temperature = 0.1 + (attempt * 0.2)

            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            outputs, gen_exc, timed_out = _run_generation_with_deadline(
                model, tokenizer, inputs, temperature
            )

            if timed_out:
                return "", False, f"Generation timed out (attempt {attempt + 1})"

            if gen_exc is not None:
                # Fall back to repr() so message-less exceptions stay visible.
                return "", False, f"Generation error: {str(gen_exc) or repr(gen_exc)}"

            if outputs is None:
                return "", False, f"Generation failed (attempt {attempt + 1})"

            # Drop the prompt tokens; keep only what the model added.
            generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
            response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

            json_str = _extract_json_text(response)

            try:
                parsed = json.loads(json_str)
                jsonschema.validate(parsed, schema)
                return json_str, True, None
            except (json.JSONDecodeError, jsonschema.ValidationError) as e:
                if is_last_attempt:
                    return json_str, False, f"JSON validation failed: {str(e)}"
                continue

        except Exception as e:
            if is_last_attempt:
                return "", False, f"Generation error: {str(e)}"
            continue

    return "", False, "All generation attempts failed"
|
|
def create_json_schema(function_def: Dict) -> Dict:
    """Build the JSON schema that a call to *function_def* must satisfy.

    The resulting schema describes an object with two required keys:
    ``name``, restricted to exactly ``function_def["name"]``, and
    ``arguments``, which must match ``function_def["parameters"]``.
    """
    # The only accepted function name is the one being described.
    name_schema = {"type": "string", "enum": [function_def["name"]]}
    call_properties = {
        "name": name_schema,
        "arguments": function_def["parameters"],
    }
    return {
        "type": "object",
        "properties": call_properties,
        "required": ["name", "arguments"],
    }
|
|
def create_test_schemas():
    """Return the sample function definitions used for testing, keyed by short name.

    Currently there is a single entry, ``"weather_forecast"``, describing
    ``get_weather_forecast(location: string, days: integer)`` with both
    parameters required. A fresh structure is built on every call.
    """
    forecast_parameters = {
        "type": "object",
        "properties": {
            "location": {"type": "string"},
            "days": {"type": "integer"},
        },
        "required": ["location", "days"],
    }
    weather_forecast = {
        "name": "get_weather_forecast",
        "description": "Get weather forecast",
        "parameters": forecast_parameters,
    }
    return {"weather_forecast": weather_forecast}
|
|
| |
if __name__ == "__main__":
    # Manual smoke test: load the model, ask for one weather-forecast call and
    # print whether the output validated against the schema.
    print("🧪 Testing SPACES-optimized model...")
    try:
        model, tokenizer = load_trained_model()

        function_def = create_test_schemas()["weather_forecast"]
        call_schema = create_json_schema(function_def)

        # The <schema> section repeats the weather_forecast definition above.
        prompt = """<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>

<schema>
{"name": "get_weather_forecast", "description": "Get weather forecast", "parameters": {"type": "object", "properties": {"location": {"type": "string"}, "days": {"type": "integer"}}, "required": ["location", "days"]}}
</schema>

<|im_start|>user
Get weather for Tokyo for 5 days<|im_end|>
<|im_start|>assistant
"""

        output, ok, problem = constrained_json_generate(model, tokenizer, prompt, call_schema)
        print(f"✅ Result: {output}")
        print(f"✅ Success: {ok}")
        if problem:
            print(f"⚠️ Error: {problem}")

    except Exception as e:
        # Top-level boundary of the script: report instead of a traceback.
        print(f"❌ Test failed: {e}")