import logging
import os
from typing import Any, Optional
import backoff
from dsp.modules.lm import LM
try:
import anthropic
anthropic_rate_limit = anthropic.RateLimitError
except ImportError:
anthropic_rate_limit = Exception
logger = logging.getLogger(__name__)
BASE_URL = "https://api.anthropic.com/v1/messages"
def backoff_hdlr(details):
    """Log each retry performed by `backoff` (handler signature from
    https://pypi.org/project/backoff/).

    Args:
        details: the dict `backoff` passes to ``on_backoff`` handlers; it
            contains at least ``wait``, ``tries``, ``target`` and ``kwargs``.
    """
    # Use the module logger instead of print() so retry diagnostics go through
    # the configured logging pipeline like the rest of this module (see
    # `logger` defined at module level) rather than leaking to stdout.
    logger.warning(
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}".format(**details),
    )
def giveup_hdlr(details):
    """Decide whether `backoff` should stop retrying.

    Note: `backoff` calls the ``giveup`` predicate with the raised *exception*
    instance, not the on_backoff details dict, so ``details`` here is an
    exception. Returns False (keep retrying) for rate-limit errors and True
    (give up) for everything else.
    """
    # Not every exception type exposes `.message` (plain `Exception` does not),
    # so fall back to str() to avoid raising AttributeError from inside the
    # retry machinery itself.
    message = getattr(details, "message", None) or str(details)
    return "rate limits" not in message
class Claude(LM):
    """Wrapper around Anthropic's Messages API.

    Review fixes relative to the previous revision:
      * ``__call__`` now accumulates completions across all ``n`` requests
        instead of overwriting the list each iteration (previously only the
        last response's texts were returned for ``n > 1``).
      * The ``request`` method (whose body had been commented out, leaving the
        ``@backoff.on_exception`` decorator attached to ``__call__``) is
        restored, so retries-with-backoff wrap each individual API request.
      * The stray debug ``print(kwargs)`` in ``basic_request`` is routed
        through ``logger.debug`` instead of stdout.
    """

    def __init__(
        self,
        model: str = "claude-instant-1.2",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """Configure the client.

        Args:
            model: Anthropic model name sent with every request.
            api_key: Anthropic API key; falls back to the ``ANTHROPIC_API_KEY``
                environment variable when None.
            api_base: base URL for the Messages endpoint; defaults to
                ``BASE_URL``.
            **kwargs: default sampling parameters (temperature, max_tokens,
                top_p, top_k, n/num_generations, ...) merged into every call.
        """
        super().__init__(model)
        try:
            from anthropic import Anthropic
        except ImportError as err:
            raise ImportError("Claude requires `pip install anthropic`.") from err

        self.provider = "anthropic"
        # Fall back to the environment variable only when no key is passed.
        self.api_key = api_key = (
            os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key
        )
        self.api_base = BASE_URL if api_base is None else api_base
        self.kwargs = {
            "temperature": kwargs.get("temperature", 0.0),
            # Clamp to 4096: the Messages API rejects larger values for these
            # models.
            "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
            "top_p": kwargs.get("top_p", 1.0),
            "top_k": kwargs.get("top_k", 1),
            # "n" is a local fan-out count, not an Anthropic API parameter;
            # pop() removes it (and its alias) from kwargs before the merge
            # below so it is not sent twice.
            "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
            **kwargs,
        }
        self.kwargs["model"] = model
        self.history: list[dict[str, Any]] = []
        self.client = Anthropic(api_key=api_key)

    def log_usage(self, response):
        """Log the total tokens from the Anthropic API response."""
        usage_data = response.usage
        if usage_data:
            total_tokens = usage_data.input_tokens + usage_data.output_tokens
            logger.info(f"{total_tokens}")

    def basic_request(self, prompt: str, **kwargs):
        """Send one ``messages.create`` call and record it in ``self.history``.

        Args:
            prompt: user prompt sent as a single-turn message.
            **kwargs: per-call overrides merged over ``self.kwargs``.

        Returns:
            The raw Anthropic Message response object.
        """
        raw_kwargs = kwargs
        # caching mechanism requires hashable kwargs
        kwargs = {**self.kwargs, **kwargs}
        kwargs["messages"] = [{"role": "user", "content": prompt}]
        # "n" is handled by __call__'s fan-out loop, never by the API; use a
        # default so a missing key cannot raise KeyError.
        kwargs.pop("n", None)
        # Debug-level logging instead of the previous print(kwargs), which
        # leaked every request's parameters to stdout.
        logger.debug("Anthropic request kwargs: %s", kwargs)
        response = self.client.messages.create(**kwargs)
        self.history.append(
            {
                "prompt": prompt,
                "response": response,
                "kwargs": kwargs,
                "raw_kwargs": raw_kwargs,
            },
        )
        return response

    @backoff.on_exception(
        backoff.expo,
        (anthropic_rate_limit),
        max_time=1000,
        max_tries=8,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handle retrieval of completions from Anthropic whilst handling API errors.

        Retries with exponential backoff on rate-limit errors; everything else
        propagates immediately via ``giveup_hdlr``.
        """
        return self.basic_request(prompt, **kwargs)

    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
        """Retrieves completions from Anthropic.

        Args:
            prompt (str): prompt to send to Anthropic
            only_completed (bool, optional): return only completed responses and ignores completion due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

        Returns:
            list[str]: list of completion choices
        """
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        # per eg here: https://docs.anthropic.com/claude/reference/messages-examples
        # max tokens can be used as a proxy to return smaller responses
        # so this cannot be a proper indicator for incomplete response unless it isnt the user-intent.
        n = kwargs.pop("n", 1)
        completions = []
        for _ in range(n):
            # Route through request() so each call gets backoff/retry handling.
            response = self.request(prompt, **kwargs)
            # TODO: Log llm usage instead of hardcoded openai usage
            # if dsp.settings.log_openai_usage:
            #     self.log_usage(response)
            if only_completed and response.stop_reason == "max_tokens":
                continue
            # BUGFIX: extend rather than assign, so generations from all n
            # requests are kept instead of only the last response's.
            completions.extend(c.text for c in response.content)
        return completions