"""Self-contained inference for SLM Function Calling on HuggingFace Spaces.""" from __future__ import annotations import json import os from pathlib import Path from typing import Any import torch from huggingface_hub import snapshot_download from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer # System prompt for function calling SYSTEM_PROMPT = ( "You are a helpful assistant. You have to either provide a way to answer " "user's request or answer user's query." ) def parse_function_call(response: str) -> dict[str, Any]: """Parse model output to extract function call. Parses the model's response in the format: {"name": "...", "arguments": "..."} <|im_end|> :param response: Raw model output string :return: Dict with 'fn_name' and 'properties' keys, or 'error' key if parsing fails """ try: # Define delimiters start_delim = " " end_delim = "<|im_end|>" # Find the JSON portion between delimiters start_idx = response.find(start_delim) if start_idx == -1: return {"error": "Start delimiter ' ' not found"} start_idx += len(start_delim) end_idx = response.find(end_delim, start_idx) if end_idx == -1: return {"error": "End delimiter '<|im_end|>' not found"} # Extract the JSON string json_str = response[start_idx:end_idx].strip() # Parse the outer JSON (contains name and arguments) function_call_dict = json.loads(json_str) # Extract function name and arguments fn_name = function_call_dict.get("name") if fn_name is None: return {"error": "Function name not found in response"} arguments_str = function_call_dict.get("arguments", "{}") # Handle arguments - convert Python-style to JSON-style if isinstance(arguments_str, str): # Replace Python boolean/None syntax with JSON syntax arguments_str = arguments_str.replace("'", '"') arguments_str = arguments_str.replace("True", "true") arguments_str = arguments_str.replace("False", "false") arguments_str = arguments_str.replace("None", "null") properties = json.loads(arguments_str) elif isinstance(arguments_str, dict): properties = arguments_str else: properties = {} return {"fn_name": fn_name, "properties": properties} except json.JSONDecodeError as e: return {"error": f"JSON parsing error: {e}"} except Exception as e: return {"error": str(e)} class Inferencer: """Inference class for SLM Function Calling model. Downloads LoRA adapter from HuggingFace Hub on initialization, or loads from a local directory if specified. Configuration via environment variables: - HF_MODEL_REPO: HuggingFace Hub repo ID (e.g., 'username/gpt2-fc-adapter') - LOCAL_CHECKPOINT_DIR: Local directory path (overrides HF_MODEL_REPO) - BASE_MODEL: Base model name (default: 'gpt2') Example:: # Set environment variable os.environ["HF_MODEL_REPO"] = "suyash94/gpt2-fc-adapter" inferencer = Inferencer() result = inferencer.predict("Set the temperature to 22 degrees") print(result["parsed"]) # {"fn_name": "set_temperature", "properties": {...}} """ def __init__( self, repo_id: str | None = None, local_dir: str | Path | None = None, base_model: str | None = None, device: torch.device | str | None = None, cache_dir: str | None = None, ) -> None: """Initialize the inferencer. 

class Inferencer:
    """Inference class for the SLM Function Calling model.

    Downloads the LoRA adapter from the HuggingFace Hub on initialization, or
    loads it from a local directory if one is specified.

    Configuration via environment variables:

    - HF_MODEL_REPO: HuggingFace Hub repo ID (e.g., 'username/gpt2-fc-adapter')
    - LOCAL_CHECKPOINT_DIR: Local directory path (overrides HF_MODEL_REPO)
    - BASE_MODEL: Base model name (default: 'gpt2')

    Example::

        # Set environment variable
        os.environ["HF_MODEL_REPO"] = "suyash94/gpt2-fc-adapter"

        inferencer = Inferencer()
        result = inferencer.predict("Set the temperature to 22 degrees")
        print(result["parsed"])  # {"fn_name": "set_temperature", "properties": {...}}
    """

    def __init__(
        self,
        repo_id: str | None = None,
        local_dir: str | Path | None = None,
        base_model: str | None = None,
        device: torch.device | str | None = None,
        cache_dir: str | None = None,
    ) -> None:
        """Initialize the inferencer.

        :param repo_id: HuggingFace Hub repo ID for the LoRA adapter
        :param local_dir: Local directory containing adapter files
        :param base_model: Base model name (default: gpt2)
        :param device: Device for inference (auto-detected if None)
        :param cache_dir: Cache directory for downloaded files
        """
        # Configuration from params or environment
        self.local_dir = local_dir or os.environ.get("LOCAL_CHECKPOINT_DIR")
        self.repo_id = repo_id or os.environ.get(
            "HF_MODEL_REPO", "suyash94/gpt2-fc-adapter"
        )
        self.base_model = base_model or os.environ.get("BASE_MODEL", "gpt2")

        if self.local_dir:
            self.local_dir = Path(self.local_dir)

        # Set device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device) if isinstance(device, str) else device

        self._model: torch.nn.Module | None = None
        self._tokenizer: PreTrainedTokenizer | None = None

        # Load model and tokenizer
        self._load_model(cache_dir)

    def _load_model(self, cache_dir: str | None = None) -> None:
        """Load the base model, tokenizer, and LoRA adapter.

        :param cache_dir: Cache directory for HuggingFace downloads
        """
        # Get adapter path (local or download from Hub)
        if self.local_dir:
            print(f"Loading adapter from local: {self.local_dir}")
            adapter_path = self.local_dir
        else:
            print(f"Downloading adapter from {self.repo_id}...")
            adapter_path = Path(
                snapshot_download(
                    repo_id=self.repo_id,
                    cache_dir=cache_dir,
                )
            )

        # Load tokenizer from the adapter (includes special tokens)
        print("Loading tokenizer from adapter...")
        self._tokenizer = AutoTokenizer.from_pretrained(
            adapter_path,
            trust_remote_code=True,
        )

        # Ensure a pad token is set
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        # Load base model
        print(f"Loading base model: {self.base_model}...")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model,
            torch_dtype=torch.float32,  # CPU-friendly
            trust_remote_code=True,
        )

        # Resize embeddings if the tokenizer has more tokens than the model
        num_embeddings = base_model.get_input_embeddings().num_embeddings
        if len(self._tokenizer) > num_embeddings:
            print(f"Resizing embeddings: {num_embeddings} -> {len(self._tokenizer)}")
            base_model.resize_token_embeddings(len(self._tokenizer))

        # Load LoRA adapter
        print("Loading LoRA adapter...")
        self._model = PeftModel.from_pretrained(
            base_model,
            adapter_path,
        )

        # Move to device and set eval mode
        self._model.to(self.device)
        self._model.eval()
        print(f"Model loaded on device: {self.device}")

    def predict(self, user_query: str, max_new_tokens: int = 128) -> dict[str, Any]:
        """Generate a function call prediction for a user query.
        :param user_query: User's natural language command
        :param max_new_tokens: Maximum new tokens to generate
        :return: Dict with 'response' and 'parsed' (function call info)
        """
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")

        # Format as chat
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_query},
        ]

        # Apply chat template
        input_text = self._tokenizer.apply_chat_template(messages, tokenize=False)

        # Tokenize
        inputs = self._tokenizer(input_text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
                do_sample=False,  # Deterministic
            )

        # Decode only the newly generated tokens. Slicing the token ids is
        # more robust than slicing the decoded string by character count,
        # since decoding does not always round-trip the prompt text exactly.
        # Special tokens are kept so the parser can find <|im_end|>.
        input_length = inputs["input_ids"].shape[1]
        response = self._tokenizer.decode(
            outputs[0][input_length:], skip_special_tokens=False
        )

        # Parse function call
        parsed = parse_function_call(response)

        return {
            "response": response,
            "parsed": parsed,
        }
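
if __name__ == "__main__":
    # Minimal smoke test (a sketch: assumes Hub access to the default adapter
    # repo; the query is the illustrative one from the class docstring).
    inferencer = Inferencer()
    result = inferencer.predict("Set the temperature to 22 degrees")
    print(result["response"])
    print(result["parsed"])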