import importlib.util
import json
import os
import time
from dataclasses import dataclass
from typing import Dict

import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces

from ..models.auto import AutoTokenizer
from ..utils import is_offline_mode, is_openai_available, is_torch_available, logging
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
from .prompts import CHAT_MESSAGE_PROMPT, download_prompt
from .python_interpreter import evaluate


logger = logging.get_logger(__name__)


if is_openai_available():
    import openai

if is_torch_available():
    from ..generation import StoppingCriteria, StoppingCriteriaList
    from ..models.auto import AutoModelForCausalLM
else:
    StoppingCriteria = object

_tools_are_initialized = False


BASE_PYTHON_TOOLS = {
    "print": print,
    "range": range,
    "float": float,
    "int": int,
    "bool": bool,
    "str": str,
}
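

# A minimal sketch of how this namespace is consumed: `evaluate` (imported above
# from `.python_interpreter`) runs generated code against these builtins plus the
# resolved tools, mutating a state dict (same call shape as in `Agent.run` below).
#
#     state = {}
#     evaluate("x = int('3')\nprint(x + 1)", BASE_PYTHON_TOOLS.copy(), state)
#     # state["x"] is now 3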


@dataclass
class PreTool:
    task: str
    description: str
    repo_id: str


HUGGINGFACE_DEFAULT_TOOLS = {}


HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
    "image-transformation",
    "text-download",
    "text-to-image",
    "text-to-video",
]


def get_remote_tools(organization="huggingface-tools"):
    if is_offline_mode():
        logger.info("You are in offline mode, so remote tools are not available.")
        return {}

    spaces = list_spaces(author=organization)
    tools = {}
    for space_info in spaces:
        repo_id = space_info.id
        resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
        with open(resolved_config_file, encoding="utf-8") as reader:
            config = json.load(reader)

        task = repo_id.split("/")[-1]
        tools[config["name"]] = PreTool(task=task, description=config["description"], repo_id=repo_id)

    return tools
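

# Illustrative sketch (not executed at import time; requires network access to the
# Hub): inspecting what `get_remote_tools` resolves from the `huggingface-tools`
# organization.
#
#     remote_tools = get_remote_tools()
#     for name, tool in remote_tools.items():
#         print(f"{name} ({tool.task}): {tool.repo_id}")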


def _setup_default_tools():
    global HUGGINGFACE_DEFAULT_TOOLS
    global _tools_are_initialized

    if _tools_are_initialized:
        return

    main_module = importlib.import_module("transformers")
    tools_module = main_module.tools

    remote_tools = get_remote_tools()
    for task_name, tool_class_name in TASK_MAPPING.items():
        tool_class = getattr(tools_module, tool_class_name)
        description = tool_class.description
        HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(task=task_name, description=description, repo_id=None)

    if not is_offline_mode():
        for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB:
            found = False
            for tool_name, tool in remote_tools.items():
                if tool.task == task_name:
                    HUGGINGFACE_DEFAULT_TOOLS[tool_name] = tool
                    found = True
                    break

            if not found:
                raise ValueError(f"{task_name} is not implemented on the Hub.")

    _tools_are_initialized = True
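

# Illustrative sketch: after `_setup_default_tools()` has run (every `Agent`
# subclass triggers it in `__init__`), the default toolbox maps tool names to
# `PreTool` placeholders that `resolve_tools` below materializes lazily.
#
#     _setup_default_tools()
#     for name, pre_tool in HUGGINGFACE_DEFAULT_TOOLS.items():
#         print(name, "->", pre_tool.repo_id or pre_tool.task)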


def resolve_tools(code, toolbox, remote=False, cached_tools=None):
    if cached_tools is None:
        resolved_tools = BASE_PYTHON_TOOLS.copy()
    else:
        resolved_tools = cached_tools
    for name, tool in toolbox.items():
        if name not in code or name in resolved_tools:
            continue

        if isinstance(tool, Tool):
            resolved_tools[name] = tool
        else:
            task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
            _remote = remote and supports_remote(task_or_repo_id)
            resolved_tools[name] = load_tool(task_or_repo_id, remote=_remote)

    return resolved_tools
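

# Illustrative sketch (assuming the default text-to-image tool is registered under
# the name `image_generator`): only tools whose names appear in the generated code
# are loaded, and already-resolved tools are reused through `cached_tools`.
#
#     code = "image = image_generator(prompt='rivers and lakes')"
#     tools = resolve_tools(code, HUGGINGFACE_DEFAULT_TOOLS)  # loads image_generator only
#     tools = resolve_tools(code, HUGGINGFACE_DEFAULT_TOOLS, cached_tools=tools)  # no reload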


def get_tool_creation_code(code, toolbox, remote=False):
    code_lines = ["from transformers import load_tool", ""]
    for name, tool in toolbox.items():
        if name not in code or isinstance(tool, Tool):
            continue

        task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
        line = f'{name} = load_tool("{task_or_repo_id}"'
        if remote:
            line += ", remote=True"
        line += ")"
        code_lines.append(line)

    return "\n".join(code_lines) + "\n"


def clean_code_for_chat(result):
    lines = result.split("\n")
    idx = 0
    while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
        idx += 1
    explanation = "\n".join(lines[:idx]).strip()
    if idx == len(lines):
        return explanation, None

    idx += 1
    start_idx = idx
    # Guard against a missing closing fence so a truncated generation does not
    # raise an IndexError; in that case all remaining lines are treated as code.
    while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
        idx += 1
    code = "\n".join(lines[start_idx:idx]).strip()

    return explanation, code
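

# Illustrative sketch of the chat-format parsing:
#
#     explanation, code = clean_code_for_chat(
#         "I will generate an image.\n```py\nimage = image_generator('a lake')\n```"
#     )
#     # explanation == "I will generate an image."
#     # code == "image = image_generator('a lake')"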


def clean_code_for_run(result):
    result = f"I will use the following {result}"
    # Split on the first "Answer:" only, in case the marker also appears inside
    # the generated code.
    explanation, code = result.split("Answer:", 1)
    explanation = explanation.strip()
    code = code.strip()

    code_lines = code.split("\n")
    if code_lines[0] in ["```", "```py", "```python"]:
        code_lines = code_lines[1:]
    if code_lines[-1] == "```":
        code_lines = code_lines[:-1]
    code = "\n".join(code_lines)

    return explanation, code
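

# Illustrative sketch of the `run`-prompt completion format ("I will use the
# following" is re-prepended above because the run prompt is expected to end with
# that phrase):
#
#     explanation, code = clean_code_for_run(
#         "tools: image_generator.\n\nAnswer:\n```py\nimage = image_generator('a lake')\n```"
#     )
#     # explanation == "I will use the following tools: image_generator."
#     # code == "image = image_generator('a lake')"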


class Agent:
    """
    Base class for all agents which contains the main API methods.

    Args:
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.
    """

    def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
        _setup_default_tools()

        agent_name = self.__class__.__name__
        self.chat_prompt_template = download_prompt(chat_prompt_template, agent_name, mode="chat")
        self.run_prompt_template = download_prompt(run_prompt_template, agent_name, mode="run")
        self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
        self.log = print
        if additional_tools is not None:
            if isinstance(additional_tools, (list, tuple)):
                additional_tools = {t.name: t for t in additional_tools}
            elif not isinstance(additional_tools, dict):
                additional_tools = {additional_tools.name: additional_tools}

            replacements = {name: tool for name, tool in additional_tools.items() if name in HUGGINGFACE_DEFAULT_TOOLS}
            self._toolbox.update(additional_tools)
            if len(replacements) > 1:
                names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()])
                logger.warning(
                    f"The following tools have been replaced by the ones provided in `additional_tools`:\n{names}."
                )
            elif len(replacements) == 1:
                name = list(replacements.keys())[0]
                logger.warning(f"{name} has been replaced by {replacements[name]} as provided in `additional_tools`.")

        self.prepare_for_new_chat()

    @property
    def toolbox(self) -> Dict[str, Tool]:
        """Get all tools currently available to the agent."""
        return self._toolbox

    def format_prompt(self, task, chat_mode=False):
        description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()])
        if chat_mode:
            if self.chat_history is None:
                prompt = self.chat_prompt_template.replace("<<all_tools>>", description)
            else:
                prompt = self.chat_history
            prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
        else:
            prompt = self.run_prompt_template.replace("<<all_tools>>", description)
            prompt = prompt.replace("<<prompt>>", task)
        return prompt

    def set_stream(self, streamer):
        """
        Set the function used to stream results (which is `print` by default).

        Args:
            streamer (`callable`): The function to call when streaming results from the LLM.
        """
        self.log = streamer

    def chat(self, task, *, return_code=False, remote=False, **kwargs):
        """
        Sends a new request to the agent in a chat. Will use the previous ones in its history.

        Args:
            task (`str`): The task to perform
            return_code (`bool`, *optional*, defaults to `False`):
                Whether to just return code and not evaluate it.
            remote (`bool`, *optional*, defaults to `False`):
                Whether or not to use remote tools (inference endpoints) instead of local ones.
            kwargs (additional keyword arguments, *optional*):
                Any keyword argument to send to the agent when evaluating the code.

        Example:

        ```py
        from transformers import HfAgent

        agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
        agent.chat("Draw me a picture of rivers and lakes")

        agent.chat("Transform the picture so that there is a rock in there")
        ```
        """
        prompt = self.format_prompt(task, chat_mode=True)
        result = self.generate_one(prompt, stop=["Human:", "====="])
        self.chat_history = prompt + result.strip() + "\n"
        explanation, code = clean_code_for_chat(result)

        self.log(f"==Explanation from the agent==\n{explanation}")

        if code is not None:
            self.log(f"\n\n==Code generated by the agent==\n{code}")
            if not return_code:
                self.log("\n\n==Result==")
                self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
                self.chat_state.update(kwargs)
                return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
            else:
                tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
                return f"{tool_code}\n{code}"

    def prepare_for_new_chat(self):
        """
        Clears the history of prior calls to [`~Agent.chat`].
        """
        self.chat_history = None
        self.chat_state = {}
        self.cached_tools = None

    def run(self, task, *, return_code=False, remote=False, **kwargs):
        """
        Sends a request to the agent.

        Args:
            task (`str`): The task to perform
            return_code (`bool`, *optional*, defaults to `False`):
                Whether to just return code and not evaluate it.
            remote (`bool`, *optional*, defaults to `False`):
                Whether or not to use remote tools (inference endpoints) instead of local ones.
            kwargs (additional keyword arguments, *optional*):
                Any keyword argument to send to the agent when evaluating the code.

        Example:

        ```py
        from transformers import HfAgent

        agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
        agent.run("Draw me a picture of rivers and lakes")
        ```
        """
        prompt = self.format_prompt(task)
        result = self.generate_one(prompt, stop=["Task:"])
        explanation, code = clean_code_for_run(result)

        self.log(f"==Explanation from the agent==\n{explanation}")

        self.log(f"\n\n==Code generated by the agent==\n{code}")
        if not return_code:
            self.log("\n\n==Result==")
            self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
            return evaluate(code, self.cached_tools, state=kwargs.copy())
        else:
            tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
            return f"{tool_code}\n{code}"

    def generate_one(self, prompt, stop):
        # To be implemented in the child class: generate a single completion for
        # `prompt`, stopping at any of the `stop` sequences.
        raise NotImplementedError

    def generate_many(self, prompts, stop):
        # Override if the backend supports batched generation faster than one by one.
        return [self.generate_one(prompt, stop) for prompt in prompts]
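

# Illustrative sketch: `generate_one` is the only method a subclass must implement.
# `FixedAgent` below is hypothetical and returns a canned completion in the format
# `clean_code_for_run` expects.
#
#     class FixedAgent(Agent):
#         def generate_one(self, prompt, stop):
#             return "tool: none.\n\nAnswer:\n```py\nprint('hello')\n```"
#
#     agent = FixedAgent()
#     agent.run("Say hello")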


class OpenAiAgent(Agent):
    """
    Agent that uses the openai API to generate code.

    <Tip warning={true}>

    The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
    `"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a future version.

    </Tip>

    Args:
        model (`str`, *optional*, defaults to `"text-davinci-003"`):
            The name of the OpenAI model to use.
        api_key (`str`, *optional*):
            The API key to use. If unset, will look for the environment variable `"OPENAI_API_KEY"`.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import OpenAiAgent

    agent = OpenAiAgent(model="text-davinci-003", api_key=xxx)
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    """

    def __init__(
        self,
        model="text-davinci-003",
        api_key=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
    ):
        if not is_openai_available():
            raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")

        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY", None)
        if api_key is None:
            raise ValueError(
                "You need an openai key to use `OpenAiAgent`. You can get one here: https://openai.com/api/. "
                "If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = xxx`."
            )
        else:
            openai.api_key = api_key
        self.model = model
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_many(self, prompts, stop):
        if "gpt" in self.model:
            return [self._chat_generate(prompt, stop) for prompt in prompts]
        else:
            return self._completion_generate(prompts, stop)

    def generate_one(self, prompt, stop):
        if "gpt" in self.model:
            return self._chat_generate(prompt, stop)
        else:
            return self._completion_generate([prompt], stop)[0]

    def _chat_generate(self, prompt, stop):
        result = openai.ChatCompletion.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            stop=stop,
        )
        return result["choices"][0]["message"]["content"]

    def _completion_generate(self, prompts, stop):
        result = openai.Completion.create(
            model=self.model,
            prompt=prompts,
            temperature=0,
            stop=stop,
            max_tokens=200,
        )
        return [answer["text"] for answer in result["choices"]]


class AzureOpenAiAgent(Agent):
    """
    Agent that uses Azure OpenAI to generate code. See the [official
    documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an openAI
    model on Azure.

    <Tip warning={true}>

    The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
    `"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a future version.

    </Tip>

    Args:
        deployment_id (`str`):
            The name of the deployed Azure openAI model to use.
        api_key (`str`, *optional*):
            The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`.
        resource_name (`str`, *optional*):
            The name of your Azure OpenAI Resource. If unset, will look for the environment variable
            `"AZURE_OPENAI_RESOURCE_NAME"`.
        api_version (`str`, *optional*, defaults to `"2022-12-01"`):
            The API version to use for this agent.
        is_chat_model (`bool`, *optional*):
            Whether you are using a completion model or a chat model (see note above, chat models won't be as
            efficient). Will default to checking whether `"gpt"` appears in `deployment_id`.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import AzureOpenAiAgent

    agent = AzureOpenAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    """

    def __init__(
        self,
        deployment_id,
        api_key=None,
        resource_name=None,
        api_version="2022-12-01",
        is_chat_model=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
    ):
        if not is_openai_available():
            raise ImportError("Using `AzureOpenAiAgent` requires `openai`: `pip install openai`.")

        self.deployment_id = deployment_id
        openai.api_type = "azure"
        if api_key is None:
            api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
        if api_key is None:
            raise ValueError(
                "You need an Azure openAI key to use `AzureOpenAiAgent`. If you have one, set it in your env with "
                "`os.environ['AZURE_OPENAI_API_KEY'] = xxx`."
            )
        else:
            openai.api_key = api_key
        if resource_name is None:
            resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME", None)
        if resource_name is None:
            raise ValueError(
                "You need a resource_name to use `AzureOpenAiAgent`. If you have one, set it in your env with "
                "`os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx`."
            )
        else:
            openai.api_base = f"https://{resource_name}.openai.azure.com"
        openai.api_version = api_version

        if is_chat_model is None:
            is_chat_model = "gpt" in deployment_id.lower()
        self.is_chat_model = is_chat_model

        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_many(self, prompts, stop):
        if self.is_chat_model:
            return [self._chat_generate(prompt, stop) for prompt in prompts]
        else:
            return self._completion_generate(prompts, stop)

    def generate_one(self, prompt, stop):
        if self.is_chat_model:
            return self._chat_generate(prompt, stop)
        else:
            return self._completion_generate([prompt], stop)[0]

    def _chat_generate(self, prompt, stop):
        result = openai.ChatCompletion.create(
            engine=self.deployment_id,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            stop=stop,
        )
        return result["choices"][0]["message"]["content"]

    def _completion_generate(self, prompts, stop):
        result = openai.Completion.create(
            engine=self.deployment_id,
            prompt=prompts,
            temperature=0,
            stop=stop,
            max_tokens=200,
        )
        return [answer["text"] for answer in result["choices"]]


class HfAgent(Agent):
    """
    Agent that uses an inference endpoint to generate code.

    Args:
        url_endpoint (`str`):
            The URL of the inference endpoint to use.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import HfAgent

    agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    """

    def __init__(
        self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
    ):
        self.url_endpoint = url_endpoint
        if token is None:
            self.token = f"Bearer {HfFolder().get_token()}"
        elif token.startswith("Bearer") or token.startswith("Basic"):
            self.token = token
        else:
            self.token = f"Bearer {token}"
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_one(self, prompt, stop):
        headers = {"Authorization": self.token}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
        }

        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            logger.info("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Error {response.status_code}: {response.json()}")

        result = response.json()[0]["generated_text"]
        # The endpoint keeps the stop sequence at the end of the generation; trim it off.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
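

# Illustrative sketch: pointing `HfAgent` at a dedicated Inference Endpoint with an
# explicit token (the URL and token below are hypothetical placeholders).
#
#     agent = HfAgent(
#         "https://my-endpoint.us-east-1.aws.endpoints.huggingface.cloud",
#         token="hf_...",
#     )
#     agent.run("Draw me a picture of rivers and lakes")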


class LocalAgent(Agent):
    """
    Agent that uses a local model and tokenizer to generate code.

    Args:
        model ([`PreTrainedModel`]):
            The model to use for the agent.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer to use for the agent.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent

    checkpoint = "bigcode/starcoder"
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    agent = LocalAgent(model, tokenizer)
    agent.run("Draw me a picture of rivers and lakes.")
    ```
    """

    def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
        self.model = model
        self.tokenizer = tokenizer
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Convenience method to build a `LocalAgent` from a pretrained checkpoint.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
            kwargs (`Dict[str, Any]`, *optional*):
                Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].

        Example:

        ```py
        import torch
        from transformers import LocalAgent

        agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
        agent.run("Draw me a picture of rivers and lakes.")
        ```
        """
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(model, tokenizer)

    @property
    def _model_device(self):
        if hasattr(self.model, "hf_device_map"):
            return list(self.model.hf_device_map.values())[0]
        for param in self.model.parameters():
            return param.device

    def generate_one(self, prompt, stop):
        encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
        src_len = encoded_inputs["input_ids"].shape[1]
        stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
        outputs = self.model.generate(
            encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
        )

        result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
        # `generate` keeps the stop sequence in the decoded output; trim it off.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                result = result[: -len(stop_seq)]
        return result


class StopSequenceCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever a sequence of tokens is encountered.

    Args:
        stop_sequences (`str` or `List[str]`):
            The sequence (or list of sequences) on which to stop execution.
        tokenizer:
            The tokenizer used to decode the model outputs.
    """

    def __init__(self, stop_sequences, tokenizer):
        if isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
        return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
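

# Illustrative sketch of standalone use with `generate`, mirroring
# `LocalAgent.generate_one` above (the checkpoint name is just an example):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     criteria = StoppingCriteriaList([StopSequenceCriteria("Task:", tokenizer)])
#     inputs = tokenizer("Example prompt", return_tensors="pt")
#     outputs = model.generate(inputs["input_ids"], stopping_criteria=criteria, max_new_tokens=50)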