| | from smolagents.models import Model, ChatMessage, Tool, MessageRole |
| | from time import sleep |
| | from typing import List, Dict, Optional |
| | from huggingface_hub import hf_hub_download |
| | import json |
| |
|
| |
|
class FakeModelReplayLog(Model):
    """A model that replays pre-recorded responses from an agent run log.

    Instead of calling a real model API, each call returns the next model
    output recorded in ``<log_folder>/metadata.json`` of the
    ``smolagents/computer-agent-logs`` dataset on the Hugging Face Hub. This
    is useful for testing and debugging agent loops without live API calls.

    Parameters:
        log_folder (str):
            Folder inside the ``smolagents/computer-agent-logs`` dataset repo
            that contains the ``metadata.json`` log to replay.
        response_delay (float, optional):
            Seconds to sleep before returning each response, to mimic model
            latency. Keyword-only. Defaults to 1.0; must be non-negative
            (pass 0 for instant replies, e.g. in tests).
        **kwargs: Additional keyword arguments passed to the Model base class.

    Raises:
        ValueError: If ``response_delay`` is negative.
    """

    def __init__(self, log_folder: str, *, response_delay: float = 1.0, **kwargs):
        # Fail fast before any base-class setup or network download.
        if response_delay < 0:
            raise ValueError(
                f"response_delay must be non-negative, got {response_delay!r}"
            )
        super().__init__(**kwargs)
        self.dataset_name = "smolagents/computer-agent-logs"
        self.log_folder = log_folder
        self.response_delay = response_delay
        # Index of the next recorded output that __call__ will return.
        self.call_counter = 0
        self.model_outputs = self._load_model_outputs()

    def _load_model_outputs(self) -> List[str]:
        """Download the run's ``metadata.json`` from the Hub and extract model outputs.

        The file is fetched with ``huggingface_hub.hf_hub_download`` from the
        dataset repo ``self.dataset_name``. The first entry of
        ``log_data["summary"]`` is skipped. Every remaining entry contributes
        its ``model_output_message["content"]``, in order.

        Returns:
            List[str]: The recorded model outputs, in replay order.
        """
        # Hub repo paths use "/" separators; strip a trailing slash so that
        # "run_1/" and "run_1" resolve to the same file instead of "run_1//...".
        file_path = hf_hub_download(
            repo_id=self.dataset_name,
            filename=f"{self.log_folder.rstrip('/')}/metadata.json",
            repo_type="dataset",
        )

        # Explicit UTF-8: logged model text may contain non-ASCII characters.
        with open(file_path, "r", encoding="utf-8") as f:
            log_data = json.load(f)

        # NOTE(review): the first summary entry is skipped, presumably because
        # it is the task/system step rather than a model call. Confirm against
        # the log format. Each remaining step is assumed to carry a
        # "model_output_message" dict with a "content" string.
        model_outputs = [
            step["model_output_message"]["content"]
            for step in log_data["summary"][1:]
        ]

        print(f"Loaded {len(model_outputs)} model outputs from log file")
        return model_outputs

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        tools_to_call_from: Optional[List[Tool]] = None,
        **kwargs,
    ) -> ChatMessage:
        """Return the next pre-recorded response from the log file.

        Once every recorded output has been returned, each further call
        returns a fixed "No more pre-recorded responses available." message.

        Parameters:
            messages: List of input messages (used only for the token estimate).
            stop_sequences: Optional list of stop sequences (ignored).
            grammar: Optional grammar specification (ignored).
            tools_to_call_from: Optional list of tools (ignored).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            ChatMessage: An assistant message holding the next recorded
            response. ``raw["call_number"]`` is the 1-based index of the
            recorded output returned. After exhaustion it stays at the number
            of recorded outputs.
        """
        # Simulate model latency so the replay paces like a live run.
        sleep(self.response_delay)

        if self.call_counter < len(self.model_outputs):
            content = self.model_outputs[self.call_counter]
            self.call_counter += 1
        else:
            content = "No more pre-recorded responses available."

        # Rough estimate of ~4 characters per token; no tokenizer is involved.
        self.last_input_token_count = len(str(messages)) // 4
        self.last_output_token_count = len(content) // 4

        return ChatMessage(
            role=MessageRole.ASSISTANT,
            content=content,
            tool_calls=None,
            raw={"source": "pre-recorded log", "call_number": self.call_counter},
        )
| |
|