"""Self-contained inference for SLM Function Calling on HuggingFace Spaces."""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any

import torch
from huggingface_hub import snapshot_download
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer

# System prompt for function calling
SYSTEM_PROMPT = (
    "You are a helpful assistant. You have to either provide a way to answer "
    "user's request or answer user's query."
)


def parse_function_call(response: str) -> dict[str, Any]:
    """Parse model output to extract function call.

    Parses the model's response in the format:
        <functioncall> {"name": "...", "arguments": "..."} <|im_end|>

    :param response: Raw model output string
    :return: Dict with 'fn_name' and 'properties' keys, or 'error' key if parsing fails
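
    Example (illustrative function name)::

        parse_function_call(
            '<functioncall> {"name": "get_weather", "arguments": "{}"} <|im_end|>'
        )
        # -> {"fn_name": "get_weather", "properties": {}}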
"""
    try:
        # Define delimiters
        start_delim = "<functioncall> "
        end_delim = "<|im_end|>"

        # Find the JSON portion between delimiters
        start_idx = response.find(start_delim)
        if start_idx == -1:
            return {"error": "Start delimiter '<functioncall> ' not found"}
        start_idx += len(start_delim)

        end_idx = response.find(end_delim, start_idx)
        if end_idx == -1:
            return {"error": "End delimiter '<|im_end|>' not found"}

        # Extract the JSON string
        json_str = response[start_idx:end_idx].strip()

        # Parse the outer JSON (contains name and arguments)
        function_call_dict = json.loads(json_str)

        # Extract function name and arguments
        fn_name = function_call_dict.get("name")
        if fn_name is None:
            return {"error": "Function name not found in response"}

        arguments_str = function_call_dict.get("arguments", "{}")

        # Handle arguments - convert Python-style to JSON-style
        if isinstance(arguments_str, str):
            # Replace Python boolean/None syntax with JSON syntax
            arguments_str = arguments_str.replace("'", '"')
            arguments_str = arguments_str.replace("True", "true")
            arguments_str = arguments_str.replace("False", "false")
            arguments_str = arguments_str.replace("None", "null")
            properties = json.loads(arguments_str)
        elif isinstance(arguments_str, dict):
            properties = arguments_str
        else:
            properties = {}

        return {"fn_name": fn_name, "properties": properties}
    except json.JSONDecodeError as e:
        return {"error": f"JSON parsing error: {e}"}
    except Exception as e:
        return {"error": str(e)}


class Inferencer:
    """Inference class for SLM Function Calling model.

    Downloads LoRA adapter from HuggingFace Hub on initialization,
    or loads from a local directory if specified.

    Configuration via environment variables:

    - HF_MODEL_REPO: HuggingFace Hub repo ID (e.g., 'username/gpt2-fc-adapter')
    - LOCAL_CHECKPOINT_DIR: Local directory path (overrides HF_MODEL_REPO)
    - BASE_MODEL: Base model name (default: 'gpt2')

    Example::

        # Set environment variable
        os.environ["HF_MODEL_REPO"] = "suyash94/gpt2-fc-adapter"

        inferencer = Inferencer()
        result = inferencer.predict("Set the temperature to 22 degrees")
        print(result["parsed"])  # {"fn_name": "set_temperature", "properties": {...}}
    """

    def __init__(
        self,
        repo_id: str | None = None,
        local_dir: str | Path | None = None,
        base_model: str | None = None,
        device: torch.device | str | None = None,
        cache_dir: str | None = None,
    ) -> None:
        """Initialize the inferencer.

        :param repo_id: HuggingFace Hub repo ID for LoRA adapter
        :param local_dir: Local directory containing adapter files
        :param base_model: Base model name (default: gpt2)
        :param device: Device for inference (auto-detected if None)
        :param cache_dir: Cache directory for downloaded files
        """
        # Configuration from params or environment
        self.local_dir = local_dir or os.environ.get("LOCAL_CHECKPOINT_DIR")
        self.repo_id = repo_id or os.environ.get("HF_MODEL_REPO", "suyash94/gpt2-fc-adapter")
        self.base_model = base_model or os.environ.get("BASE_MODEL", "gpt2")

        if self.local_dir:
            self.local_dir = Path(self.local_dir)

        # Set device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device) if isinstance(device, str) else device

        self._model: torch.nn.Module | None = None
        self._tokenizer: PreTrainedTokenizer | None = None

        # Load model and tokenizer
        self._load_model(cache_dir)

    def _load_model(self, cache_dir: str | None = None) -> None:
        """Load base model, tokenizer, and LoRA adapter.

        :param cache_dir: Cache directory for HuggingFace downloads
        """
        # Get adapter path (local or download from Hub)
        if self.local_dir:
            print(f"Loading adapter from local: {self.local_dir}")
            adapter_path = self.local_dir
        else:
            print(f"Downloading adapter from {self.repo_id}...")
            adapter_path = Path(
                snapshot_download(
                    repo_id=self.repo_id,
                    cache_dir=cache_dir,
                )
            )

        # Load tokenizer from adapter (includes special tokens)
        print("Loading tokenizer from adapter...")
        self._tokenizer = AutoTokenizer.from_pretrained(
            adapter_path,
            trust_remote_code=True,
        )

        # Ensure pad token is set
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        # Load base model
        print(f"Loading base model: {self.base_model}...")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model,
            torch_dtype=torch.float32,  # CPU-friendly
            trust_remote_code=True,
        )

        # Resize embeddings if the tokenizer has more tokens than the model
        embedding_size = base_model.get_input_embeddings().num_embeddings
        if len(self._tokenizer) > embedding_size:
            print(f"Resizing embeddings: {embedding_size} -> {len(self._tokenizer)}")
            base_model.resize_token_embeddings(len(self._tokenizer))

        # Load LoRA adapter
        print("Loading LoRA adapter...")
        self._model = PeftModel.from_pretrained(
            base_model,
            adapter_path,
        )

        # Move to device and set eval mode
        self._model.to(self.device)
        self._model.eval()
        print(f"Model loaded on device: {self.device}")

    def predict(self, user_query: str, max_new_tokens: int = 128) -> dict[str, Any]:
        """Generate a function call prediction for a user query.

        :param user_query: User's natural language command
        :param max_new_tokens: Maximum new tokens to generate
        :return: Dict with 'response' and 'parsed' (function call info)
        """
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")

        # Format as chat
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_query},
        ]

        # Apply chat template
        input_text = self._tokenizer.apply_chat_template(messages, tokenize=False)

        # Tokenize
        inputs = self._tokenizer(input_text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
                do_sample=False,  # Deterministic
            )

        # Decode only the generated part: slice off the prompt by token count rather
        # than by character count, since decoding does not always round-trip exactly
        prompt_length = inputs["input_ids"].shape[-1]
        response = self._tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)

        # Parse function call
        parsed = parse_function_call(response)

        return {
            "response": response,
            "parsed": parsed,
        }
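

if __name__ == "__main__":
    # Minimal manual smoke test (a sketch, not part of the Space app). It assumes the
    # default adapter repo configured above is reachable and that the base model fits
    # on the local machine. The query mirrors the example in the class docstring.
    inferencer = Inferencer()
    result = inferencer.predict("Set the temperature to 22 degrees")
    print("Raw response:", result["response"])
    print("Parsed call:", result["parsed"])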