| """Self-contained inference for SLM Function Calling on HuggingFace Spaces.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from peft import PeftModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer | |
| # System prompt for function calling | |
| SYSTEM_PROMPT = ( | |
| "You are a helpful assistant. You have to either provide a way to answer " | |
| "user's request or answer user's query." | |
| ) | |
| def parse_function_call(response: str) -> dict[str, Any]: | |
| """Parse model output to extract function call. | |
| Parses the model's response in the format: | |
| <functioncall> {"name": "...", "arguments": "..."} <|im_end|> | |
| :param response: Raw model output string | |
| :return: Dict with 'fn_name' and 'properties' keys, or 'error' key if parsing fails | |
| """ | |
| try: | |
| # Define delimiters | |
| start_delim = "<functioncall> " | |
| end_delim = "<|im_end|>" | |
| # Find the JSON portion between delimiters | |
| start_idx = response.find(start_delim) | |
| if start_idx == -1: | |
| return {"error": "Start delimiter '<functioncall> ' not found"} | |
| start_idx += len(start_delim) | |
| end_idx = response.find(end_delim, start_idx) | |
| if end_idx == -1: | |
| return {"error": "End delimiter '<|im_end|>' not found"} | |
| # Extract the JSON string | |
| json_str = response[start_idx:end_idx].strip() | |
| # Parse the outer JSON (contains name and arguments) | |
| function_call_dict = json.loads(json_str) | |
| # Extract function name and arguments | |
| fn_name = function_call_dict.get("name") | |
| if fn_name is None: | |
| return {"error": "Function name not found in response"} | |
| arguments_str = function_call_dict.get("arguments", "{}") | |
| # Handle arguments - convert Python-style to JSON-style | |
| if isinstance(arguments_str, str): | |
| # Replace Python boolean/None syntax with JSON syntax | |
| arguments_str = arguments_str.replace("'", '"') | |
| arguments_str = arguments_str.replace("True", "true") | |
| arguments_str = arguments_str.replace("False", "false") | |
| arguments_str = arguments_str.replace("None", "null") | |
| properties = json.loads(arguments_str) | |
| elif isinstance(arguments_str, dict): | |
| properties = arguments_str | |
| else: | |
| properties = {} | |
| return {"fn_name": fn_name, "properties": properties} | |
| except json.JSONDecodeError as e: | |
| return {"error": f"JSON parsing error: {e}"} | |
| except Exception as e: | |
| return {"error": str(e)} | |


class Inferencer:
    """Inference class for the SLM Function Calling model.

    Downloads the LoRA adapter from HuggingFace Hub on initialization,
    or loads it from a local directory if specified.

    Configuration via environment variables:

    - HF_MODEL_REPO: HuggingFace Hub repo ID (e.g., 'username/gpt2-fc-adapter')
    - LOCAL_CHECKPOINT_DIR: Local directory path (overrides HF_MODEL_REPO)
    - BASE_MODEL: Base model name (default: 'gpt2')

    Example::

        # Set environment variable
        os.environ["HF_MODEL_REPO"] = "suyash94/gpt2-fc-adapter"

        inferencer = Inferencer()
        result = inferencer.predict("Set the temperature to 22 degrees")
        print(result["parsed"])  # {"fn_name": "set_temperature", "properties": {...}}
    """

    def __init__(
        self,
        repo_id: str | None = None,
        local_dir: str | Path | None = None,
        base_model: str | None = None,
        device: torch.device | str | None = None,
        cache_dir: str | None = None,
    ) -> None:
        """Initialize the inferencer.

        :param repo_id: HuggingFace Hub repo ID for the LoRA adapter
        :param local_dir: Local directory containing adapter files
        :param base_model: Base model name (default: gpt2)
        :param device: Device for inference (auto-detected if None)
        :param cache_dir: Cache directory for downloaded files
        """
        # Configuration from params or environment
        self.local_dir = local_dir or os.environ.get("LOCAL_CHECKPOINT_DIR")
        self.repo_id = repo_id or os.environ.get("HF_MODEL_REPO", "suyash94/gpt2-fc-adapter")
        self.base_model = base_model or os.environ.get("BASE_MODEL", "gpt2")
        if self.local_dir:
            self.local_dir = Path(self.local_dir)

        # Set device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device) if isinstance(device, str) else device

        self._model: torch.nn.Module | None = None
        self._tokenizer: PreTrainedTokenizer | None = None

        # Load model and tokenizer
        self._load_model(cache_dir)

    def _load_model(self, cache_dir: str | None = None) -> None:
        """Load base model, tokenizer, and LoRA adapter.

        :param cache_dir: Cache directory for HuggingFace downloads
        """
        # Get adapter path (local, or download from the Hub)
        if self.local_dir:
            print(f"Loading adapter from local: {self.local_dir}")
            adapter_path = self.local_dir
        else:
            print(f"Downloading adapter from {self.repo_id}...")
            adapter_path = Path(
                snapshot_download(
                    repo_id=self.repo_id,
                    cache_dir=cache_dir,
                )
            )

        # Load tokenizer from adapter (includes special tokens)
        print("Loading tokenizer from adapter...")
        self._tokenizer = AutoTokenizer.from_pretrained(
            adapter_path,
            trust_remote_code=True,
        )
        # Ensure pad token is set
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        # Load base model
        print(f"Loading base model: {self.base_model}...")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model,
            torch_dtype=torch.float32,  # CPU-friendly
            trust_remote_code=True,
        )

        # Resize embeddings if the tokenizer has more tokens than the model
        if len(self._tokenizer) > base_model.get_input_embeddings().num_embeddings:
            print(
                "Resizing embeddings: "
                f"{base_model.get_input_embeddings().num_embeddings} -> {len(self._tokenizer)}"
            )
            base_model.resize_token_embeddings(len(self._tokenizer))

        # Load LoRA adapter
        print("Loading LoRA adapter...")
        self._model = PeftModel.from_pretrained(
            base_model,
            adapter_path,
        )

        # Move to device and set eval mode
        self._model.to(self.device)
        self._model.eval()
        print(f"Model loaded on device: {self.device}")
    def predict(self, user_query: str, max_new_tokens: int = 128) -> dict[str, Any]:
        """Generate a function call prediction for a user query.

        :param user_query: User's natural language command
        :param max_new_tokens: Maximum new tokens to generate
        :return: Dict with 'response' and 'parsed' (function call info)
        """
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")

        # Format as chat
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_query},
        ]

        # Apply chat template
        input_text = self._tokenizer.apply_chat_template(messages, tokenize=False)

        # Tokenize
        inputs = self._tokenizer(input_text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
                do_sample=False,  # Deterministic
            )
        # Decode only the newly generated tokens; slicing by token count is more
        # robust than slicing the decoded string by the prompt's character length.
        # Special tokens are kept because the parser relies on <|im_end|>.
        prompt_length = inputs["input_ids"].shape[-1]
        response = self._tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)

        # Parse function call
        parsed = parse_function_call(response)

        return {
            "response": response,
            "parsed": parsed,
        }
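

if __name__ == "__main__":
    # Minimal smoke test, mirroring the usage example in the class docstring.
    # It assumes the default adapter repo ('suyash94/gpt2-fc-adapter') is
    # reachable, or that HF_MODEL_REPO / LOCAL_CHECKPOINT_DIR point at a valid
    # adapter; the sample query is illustrative only.
    inferencer = Inferencer()
    result = inferencer.predict("Set the temperature to 22 degrees")
    print("Raw response:", result["response"])
    print("Parsed call:", result["parsed"])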