File size: 4,799 Bytes
ed83d97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import logging
import os
from typing import Any, Optional

import backoff

from dsp.modules.lm import LM

# The anthropic SDK is an optional dependency: resolve the exception type the
# retry decorator further down listens for, falling back to the base
# Exception class when the package is not installed.
try:
    import anthropic

    anthropic_rate_limit = anthropic.RateLimitError
except ImportError:
    anthropic_rate_limit = Exception


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default value stored as `api_base` when the caller does not supply one.
BASE_URL = "https://api.anthropic.com/v1/messages"


def backoff_hdlr(details):
    """Print a one-line notice describing the retry that is about to happen.

    Handler from https://pypi.org/project/backoff/

    Args:
        details: mapping that must provide the ``wait``, ``tries``,
            ``target`` and ``kwargs`` keys; additional keys are ignored.
    """
    template = (
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}"
    )
    notice = template.format(**details)
    print(notice)


def giveup_hdlr(details):
    """Decide whether retrying should stop for the given error.

    Args:
        details: the error object handed to the ``giveup`` callback. Its
            ``message`` attribute is inspected when it is a string; otherwise
            ``str(details)`` is used, so plain exceptions without a
            ``message`` attribute no longer raise ``AttributeError`` here.

    Returns:
        bool: ``False`` (keep retrying) when the error text mentions
        "rate limits", ``True`` (give up) for anything else.
    """
    message = getattr(details, "message", None)
    if not isinstance(message, str):
        # Plain exceptions carry their text in args, not in `.message`.
        message = str(details)
    return "rate limits" not in message


class Claude(LM):
    """Wrapper around Anthropic's Messages API.

    The client is built from the API key only; ``api_base`` is stored on the
    instance but is not forwarded to the Anthropic client.
    """

    def __init__(
        self,
        model: str = "claude-instant-1.2",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """Create the Anthropic client and the default request parameters.

        Args:
            model: name of the model sent with every request.
            api_key: API key; read from the ``ANTHROPIC_API_KEY`` environment
                variable when ``None``.
            api_base: stored as ``self.api_base`` (``BASE_URL`` when ``None``).
            **kwargs: default request parameters. ``temperature`` (0.0),
                ``top_p`` (1.0) and ``top_k`` (1) get defaults when absent,
                ``max_tokens`` is capped at 4096, and ``n`` /
                ``num_generations`` is kept under the ``"n"`` key.

        Raises:
            ImportError: if the ``anthropic`` package is not installed.
        """
        super().__init__(model)

        try:
            from anthropic import Anthropic
        except ImportError as err:
            raise ImportError("Claude requires `pip install anthropic`.") from err

        self.provider = "anthropic"
        self.api_key = api_key = (
            os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key
        )
        self.api_base = BASE_URL if api_base is None else api_base

        # Pull both spellings out first so neither reaches the request kwargs.
        n = kwargs.pop("n", kwargs.pop("num_generations", 1))
        self.kwargs = {
            "temperature": 0.0,
            "top_p": 1.0,
            "top_k": 1,
            **kwargs,
            # Applied after the spread so a caller-supplied value cannot
            # bypass the 4096 cap.
            "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
            "n": n,
        }
        self.kwargs["model"] = model
        self.history: list[dict[str, Any]] = []
        self.client = Anthropic(api_key=api_key)

    def log_usage(self, response):
        """Log the total tokens from the Anthropic API response."""
        usage_data = response.usage
        if usage_data:
            total_tokens = usage_data.input_tokens + usage_data.output_tokens
            logger.info("%s", total_tokens)

    def basic_request(self, prompt: str, **kwargs):
        """Send one prompt as a single user message and record the exchange.

        Args:
            prompt: text sent as the content of the user message.
            **kwargs: per-call parameters that override the instance defaults.

        Returns:
            The response object returned by ``self.client.messages.create``.
        """
        raw_kwargs = kwargs

        kwargs = {**self.kwargs, **kwargs}
        kwargs["messages"] = [{"role": "user", "content": prompt}]
        # `n` is consumed by __call__ (one request per sample) and is not
        # forwarded to the API call.
        kwargs.pop("n", None)
        logger.debug("Anthropic request kwargs: %s", kwargs)
        response = self.client.messages.create(**kwargs)

        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return response

    @backoff.on_exception(
        backoff.expo,
        (anthropic_rate_limit),
        max_time=1000,
        max_tries=8,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Anthropic whilst handling API errors."""
        return self.basic_request(prompt, **kwargs)

    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
        """Retrieves completions from Anthropic.

        Args:
            prompt (str): prompt to send to Anthropic
            only_completed (bool, optional): return only completed responses and ignores completion due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.
            **kwargs: per-call request parameters; ``n`` (default 1) is the
                number of requests to issue.

        Returns:
            list[str]: list of completion choices, accumulated over all ``n``
            requests; responses that stopped on ``max_tokens`` are skipped.
        """

        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        # per eg here: https://docs.anthropic.com/claude/reference/messages-examples
        # max tokens can be used as a proxy to return smaller responses
        # so this cannot be a proper indicator for incomplete response unless it isnt the user-intent.
        # if only_completed and response.stop_reason != "end_turn":
        #     choices = []

        n = kwargs.pop("n", 1)
        completions = []
        for _ in range(n):
            # Retries happen per request, so a rate-limit error does not
            # restart the samples already collected.
            response = self.request(prompt, **kwargs)
            # TODO: Log llm usage instead of hardcoded openai usage
            # if dsp.settings.log_openai_usage:
            #     self.log_usage(response)
            if only_completed and response.stop_reason == "max_tokens":
                continue
            # Accumulate instead of overwriting so every sample is returned.
            completions.extend(c.text for c in response.content)
        return completions