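"""
constrained_generator.py - JSON Schema Constrained Generation

This module implements constrained decoding that steers a causal language
model toward valid JSON output:
1. Token-by-token filtering: at each step only tokens accepted by
   `is_valid_json_prefix` are eligible for sampling (there is no
   backtracking; an emitted token is never retracted)
2. A fallback that appends a closing brace when no candidate token is valid
3. Beam search whose finished candidates are checked against the JSON schema
4. Schema validation of the finished output (the schema is checked after
   generation; it does not steer token selection)
"""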
|
|
import copy
import json
import re
from typing import Any, Dict, Iterable, List, Optional

import jsonschema
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
class ConstrainedJSONGenerator:
    """Generate JSON-object text from a causal LM under a syntax constraint.

    During ``generate_constrained`` every sampled token must keep the output a
    prefix of some syntactically valid JSON object (see
    ``is_valid_json_prefix``). The JSON schema does not steer token choice; it
    is only checked against complete outputs.
    """

    class _Truncated(Exception):
        """Internal signal: the text ended before the JSON value being scanned."""

    # JSON grammar character classes (RFC 8259).
    _WHITESPACE = " \t\n\r"
    _DIGITS = "0123456789"
    _HEX_DIGITS = "0123456789abcdefABCDEF"
    _SIMPLE_ESCAPES = '"\\/bfnrt'
    # Literal keywords, keyed by their first character.
    _LITERALS = {"t": "true", "f": "false", "n": "null"}

    def __init__(self, model, tokenizer, device="mps"):
        """Wrap *model* and *tokenizer*; new tensors are created on *device*.

        The model is put into eval mode.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()
        # token id -> decoded text; generation re-checks the same ids every step.
        self._token_text_cache: Dict[int, str] = {}

    # ------------------------------------------------------------------
    # JSON prefix scanner.  Each _scan_* method takes the start index and
    # returns the index just past what it consumed.  It raises ValueError on
    # a syntax error and _Truncated when the text ends before the value does.
    # ------------------------------------------------------------------

    @classmethod
    def _peek(cls, text: str, pos: int) -> str:
        """Return ``text[pos]``, or raise ``_Truncated`` past the end."""
        if pos >= len(text):
            raise cls._Truncated()
        return text[pos]

    @classmethod
    def _skip_ws(cls, text: str, pos: int) -> int:
        """Return the first index at or after *pos* that is not JSON whitespace."""
        while pos < len(text) and text[pos] in cls._WHITESPACE:
            pos += 1
        return pos

    @classmethod
    def _skip_digits(cls, text: str, pos: int) -> int:
        """Return the first index at or after *pos* that is not an ASCII digit."""
        while pos < len(text) and text[pos] in cls._DIGITS:
            pos += 1
        return pos

    @classmethod
    def _require_digits(cls, text: str, pos: int) -> int:
        """Consume one or more digits starting at *pos*."""
        if cls._peek(text, pos) not in cls._DIGITS:
            raise ValueError(f"expected a digit at index {pos}")
        return cls._skip_digits(text, pos)

    @classmethod
    def _scan_value(cls, text: str, pos: int) -> int:
        """Scan one JSON value at *pos*; leading whitespace is allowed."""
        pos = cls._skip_ws(text, pos)
        ch = cls._peek(text, pos)
        if ch == "{":
            return cls._scan_container(text, pos, "}", keyed=True)
        if ch == "[":
            return cls._scan_container(text, pos, "]", keyed=False)
        if ch == '"':
            return cls._scan_string(text, pos)
        if ch == "-" or ch in cls._DIGITS:
            return cls._scan_number(text, pos)
        literal = cls._LITERALS.get(ch)
        if literal is None:
            raise ValueError(f"unexpected character {ch!r} at index {pos}")
        chunk = text[pos:pos + len(literal)]
        if not literal.startswith(chunk):
            raise ValueError(f"invalid literal at index {pos}")
        if len(chunk) < len(literal):
            raise cls._Truncated()
        return pos + len(literal)

    @classmethod
    def _scan_container(cls, text: str, pos: int, closer: str, keyed: bool) -> int:
        """Scan an object (``keyed=True``) or array whose opener is at *pos*."""
        pos = cls._skip_ws(text, pos + 1)  # step over '{' or '['
        if cls._peek(text, pos) == closer:
            return pos + 1
        while True:
            if keyed:
                pos = cls._skip_ws(text, pos)
                if cls._peek(text, pos) != '"':
                    raise ValueError(f"object key must be a string at index {pos}")
                pos = cls._skip_ws(text, cls._scan_string(text, pos))
                if cls._peek(text, pos) != ":":
                    raise ValueError(f"expected ':' at index {pos}")
                pos += 1
            pos = cls._skip_ws(text, cls._scan_value(text, pos))
            ch = cls._peek(text, pos)
            if ch == closer:
                return pos + 1
            if ch != ",":
                raise ValueError(f"expected ',' or {closer!r} at index {pos}")
            pos += 1

    @classmethod
    def _scan_string(cls, text: str, pos: int) -> int:
        """Scan a string literal whose opening quote is at *pos*."""
        pos += 1
        while True:
            ch = cls._peek(text, pos)
            if ch == '"':
                return pos + 1
            if ch == "\\":
                escape = cls._peek(text, pos + 1)
                if escape == "u":
                    for offset in range(2, 6):
                        if cls._peek(text, pos + offset) not in cls._HEX_DIGITS:
                            raise ValueError(f"bad \\u escape at index {pos}")
                    pos += 6
                elif escape in cls._SIMPLE_ESCAPES:
                    pos += 2
                else:
                    raise ValueError(f"bad escape at index {pos}")
            elif ch < " ":
                # JSON forbids raw control characters inside strings.
                raise ValueError(f"control character in string at index {pos}")
            else:
                pos += 1

    @classmethod
    def _scan_number(cls, text: str, pos: int) -> int:
        """Scan a number: ``-? (0 | [1-9][0-9]*) (. digits)? ([eE] [+-]? digits)?``."""
        if text[pos] == "-":
            pos += 1
        ch = cls._peek(text, pos)
        if ch == "0":
            pos += 1
        elif ch in cls._DIGITS:
            pos = cls._skip_digits(text, pos)
        else:
            raise ValueError(f"expected a digit at index {pos}")
        if pos < len(text) and text[pos] == ".":
            pos = cls._require_digits(text, pos + 1)
        if pos < len(text) and text[pos] in "eE":
            pos += 1
            if cls._peek(text, pos) in "+-":
                pos += 1
            pos = cls._require_digits(text, pos)
        if pos >= len(text):
            # More digits / '.' / exponent could still follow, so a number
            # touching the end of the text is unfinished, not complete.
            raise cls._Truncated()
        return pos

    def is_valid_json_prefix(self, text: str) -> bool:
        """Check if text could be the start of valid JSON.

        Surrounding whitespace is ignored. Returns True when *text* is empty,
        a complete JSON object, or can still be extended into one; False for
        anything else (including top-level values that are not objects and
        data after a complete object).
        """
        text = text.strip()
        if not text:
            return True
        if not text.startswith("{"):
            return False
        try:
            end = self._scan_value(text, 0)
        except self._Truncated:
            return True
        except (ValueError, RecursionError):
            return False
        # text has no trailing whitespace, so anything left over is extra data.
        return end == len(text)

    def _excluded_token_ids(self) -> set:
        """Ids never offered as candidates: all special tokens plus padding."""
        excluded = set(getattr(self.tokenizer, "all_special_ids", None) or ())
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is not None:
            excluded.add(pad_id)
        return excluded

    def _token_text(self, token_id: int) -> str:
        """Decode a single token id, memoised per generator instance."""
        text = self._token_text_cache.get(token_id)
        if text is None:
            text = self.tokenizer.decode([token_id])
            self._token_text_cache[token_id] = text
        return text

    def get_valid_next_tokens(
        self,
        current_text: str,
        schema: Dict,
        candidate_ids: Optional[Iterable[int]] = None,
        max_valid: Optional[int] = 50,
    ) -> List[int]:
        """Get tokens that would keep JSON valid.

        Args:
            current_text: Text generated so far (without the prompt).
            schema: Accepted for interface compatibility; not consulted here.
            candidate_ids: Ids to test, in priority order. Defaults to every
                id of the tokenizer in ascending order.
            max_valid: Stop once this many valid ids are found (should be
                positive); ``None`` tests every candidate.

        Returns:
            Valid ids in the order they were tested. Special tokens, padding
            and tokens that decode to empty text are always skipped.
        """
        if candidate_ids is None:
            candidate_ids = range(len(self.tokenizer))
        excluded = self._excluded_token_ids()

        valid_tokens: List[int] = []
        for token_id in candidate_ids:
            if token_id in excluded:
                continue
            token_text = self._token_text(token_id)
            if not token_text:
                continue  # would make no progress
            if self.is_valid_json_prefix(current_text + token_text):
                valid_tokens.append(token_id)
                if max_valid is not None and len(valid_tokens) >= max_valid:
                    break
        return valid_tokens

    def generate_constrained(self, prompt: str, schema: Dict, max_length: int = 200) -> str:
        """Generate text with JSON constraints.

        Samples at most *max_length* tokens after *prompt*. At each step the
        model's candidates are tested best-logit-first and the first 50 that
        keep the output a valid JSON-object prefix form the sampling pool;
        the next token is sampled from the softmax restricted to that pool.
        If nothing is valid, a ``}`` is forced once. Generation stops as soon
        as the text parses as a complete JSON value, because only whitespace
        could follow it. *schema* does not influence token choice.

        Token texts are decoded one id at a time and concatenated.
        NOTE(review): this assumes single-token decoding preserves spacing
        (true for byte-level BPE); confirm for other tokenizers.

        Returns:
            The generated text stripped of surrounding whitespace. It may be
            incomplete JSON if *max_length* is reached first.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        all_input_ids = inputs["input_ids"]
        next_input_ids = all_input_ids
        past_key_values = None
        generated_text = ""

        for _ in range(max_length):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=next_input_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
            past_key_values = getattr(outputs, "past_key_values", None)
            logits = outputs.logits[0, -1, :]

            # Only ids that both the model and the tokenizer know are usable.
            vocab_limit = min(logits.shape[-1], len(self.tokenizer))
            ranked_ids = torch.argsort(logits[:vocab_limit], descending=True).tolist()
            valid_tokens = self.get_valid_next_tokens(
                generated_text, schema, candidate_ids=ranked_ids
            )

            if valid_tokens:
                mask = torch.full_like(logits, float("-inf"))
                mask[valid_tokens] = 0
                probs = torch.softmax(logits + mask, dim=-1)
                next_token_id = int(torch.multinomial(probs, 1).item())
            elif not generated_text.strip().endswith("}"):
                # No token keeps the text valid: force a closing brace.
                # add_special_tokens=False so a BOS id is not returned instead.
                next_token_id = self.tokenizer.encode("}", add_special_tokens=False)[0]
            else:
                break

            new_token = torch.tensor([[next_token_id]], device=self.device)
            all_input_ids = torch.cat([all_input_ids, new_token], dim=1)
            # With a KV cache only the new token is fed; otherwise re-feed everything.
            next_input_ids = new_token if past_key_values is not None else all_input_ids
            generated_text += self._token_text(next_token_id)

            try:
                json.loads(generated_text.strip())
            except json.JSONDecodeError:
                continue  # still an open prefix
            break  # complete value: only whitespace could be appended

        return generated_text.strip()

    def validate_against_schema(self, data: Any, schema: Dict) -> bool:
        """Validate JSON data against schema.

        Returns False when *data* violates *schema*. A malformed *schema*
        (``jsonschema.SchemaError``) is not caught and propagates.
        """
        try:
            jsonschema.validate(data, schema)
        except jsonschema.ValidationError:
            return False
        return True

    def generate_with_beam_search(
        self,
        prompt: str,
        schema: Dict,
        num_beams: int = 3,
        max_new_tokens: int = 150,
    ) -> str:
        """Generate with beam search, then pick the first schema-valid beam.

        Decoding itself is unconstrained deterministic beam search. The
        *num_beams* returned sequences are checked in order. The first one
        that parses as JSON and satisfies *schema* is returned. If none does,
        the first beam is returned as-is, or ``""`` when there are no beams.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                early_stopping=True,
                do_sample=False,  # temperature would be ignored, so none is passed
                pad_token_id=self.tokenizer.eos_token_id,
                num_return_sequences=num_beams,
            )

        candidates = [
            self.tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True).strip()
            for sequence in outputs
        ]

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if self.validate_against_schema(parsed, schema):
                return candidate

        return candidates[0] if candidates else ""
|
|
def create_json_schema_from_function(function_def: Dict) -> Dict:
    """Create a JSON schema for validating function calls.

    The accepted object has exactly two keys. ``"name"`` must equal
    ``function_def["name"]``. ``"arguments"`` is validated by
    ``function_def["parameters"]``; a definition without ``"parameters"``
    (or with ``None``) is treated as taking an arbitrary JSON object.

    The parameter schema is deep-copied, so mutating the returned schema never
    alters the caller's function definition.

    Raises:
        KeyError: if ``function_def`` has no ``"name"``.
    """
    parameters = function_def.get("parameters")
    if parameters is None:
        parameters = {"type": "object", "properties": {}}
    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string",
                "const": function_def["name"],
            },
            "arguments": copy.deepcopy(parameters),
        },
        "required": ["name", "arguments"],
        "additionalProperties": False,
    }
|
|
def _check_generated_json(generator, text: str, schema: Dict):
    """Parse *text* as JSON and validate it with *generator* against *schema*.

    Returns ``(ok, detail)``. ``detail`` holds the parse error message, or a
    schema-mismatch note, when ``ok`` is False, and ``""`` otherwise.
    """
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as err:
        return False, str(err)
    if not generator.validate_against_schema(parsed, schema):
        return False, "output does not match the schema"
    return True, ""


def test_constrained_generation():
    """Test the constrained generator end to end.

    Loads SmolLM3-3B via ``from_pretrained`` and runs both generation
    strategies, printing the outputs. This is a manual smoke test, not a
    unit test.
    """
    print("🧪 Testing Constrained JSON Generation...")

    model_name = "HuggingFaceTB/SmolLM3-3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="mps" if torch.backends.mps.is_available() else "auto",
    )

    # Put inputs wherever the model actually landed; the generator's "mps"
    # default fails on machines without Apple-silicon GPUs.
    generator = ConstrainedJSONGenerator(model, tokenizer, device=model.device)

    function_def = {
        "name": "get_weather",
        "description": "Get weather forecast",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "days": {"type": "integer"},
            },
            "required": ["location", "days"],
        },
    }

    schema = create_json_schema_from_function(function_def)

    # The schema lives inside the system turn so the ChatML structure stays intact.
    prompt = f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.

<schema>
{json.dumps(function_def, indent=2)}
</schema><|im_end|>
<|im_start|>user
Get 3-day weather for New York<|im_end|>
<|im_start|>assistant
"""

    print("🎯 Testing constrained generation...")
    result = generator.generate_constrained(prompt, schema)
    print(f"🚀 Constrained result: {result}")

    ok, detail = _check_generated_json(generator, result, schema)
    if ok:
        print("✅ Valid JSON with correct schema!")
    else:
        print(f"❌ Validation failed: {detail}")

    print("🎯 Testing beam search...")
    beam_result = generator.generate_with_beam_search(prompt, schema)
    print(f"🚀 Beam result: {beam_result}")

    ok, detail = _check_generated_json(generator, beam_result, schema)
    if ok:
        print("✅ Beam search produced valid JSON!")
    else:
        print(f"❌ Beam validation failed: {detail}")
|
|
if __name__ == "__main__":
    # Executed only when run as a script: runs the end-to-end demo, which
    # loads the model with from_pretrained.
    test_constrained_generation()