Adrien committed on
Commit
ed83d97
·
1 Parent(s): 428cc4d

fix claude

Browse files
Files changed (3) hide show
  1. Claude.py +145 -0
  2. app.py +3 -3
  3. writer.py +1 -1
Claude.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Any, Optional
4
+
5
+ import backoff
6
+
7
+ from dsp.modules.lm import LM
8
+
9
+ try:
10
+ import anthropic
11
+
12
+ anthropic_rate_limit = anthropic.RateLimitError
13
+ except ImportError:
14
+ anthropic_rate_limit = Exception
15
+
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ BASE_URL = "https://api.anthropic.com/v1/messages"
20
+
21
+
22
+ def backoff_hdlr(details):
23
+ """Handler from https://pypi.org/project/backoff/"""
24
+ print(
25
+ "Backing off {wait:0.1f} seconds after {tries} tries "
26
+ "calling function {target} with kwargs "
27
+ "{kwargs}".format(**details),
28
+ )
29
+
30
+
31
+ def giveup_hdlr(details):
32
+ """wrapper function that decides when to give up on retry"""
33
+ if "rate limits" in details.message:
34
+ return False
35
+ return True
36
+
37
+
38
+ class Claude(LM):
39
+ """Wrapper around anthropic's API. Supports both the Anthropic and Azure APIs."""
40
+
41
+ def __init__(
42
+ self,
43
+ model: str = "claude-instant-1.2",
44
+ api_key: Optional[str] = None,
45
+ api_base: Optional[str] = None,
46
+ **kwargs,
47
+ ):
48
+ super().__init__(model)
49
+
50
+ try:
51
+ from anthropic import Anthropic, RateLimitError
52
+ except ImportError as err:
53
+ raise ImportError("Claude requires `pip install anthropic`.") from err
54
+
55
+ self.provider = "anthropic"
56
+ self.api_key = api_key = (
57
+ os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key
58
+ )
59
+ self.api_base = BASE_URL if api_base is None else api_base
60
+
61
+ self.kwargs = {
62
+ "temperature": 0.0
63
+ if "temperature" not in kwargs
64
+ else kwargs["temperature"],
65
+ "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
66
+ "top_p": 1.0 if "top_p" not in kwargs else kwargs["top_p"],
67
+ "top_k": 1 if "top_k" not in kwargs else kwargs["top_k"],
68
+ "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
69
+ **kwargs,
70
+ }
71
+ self.kwargs["model"] = model
72
+ self.history: list[dict[str, Any]] = []
73
+ self.client = Anthropic(api_key=api_key)
74
+
75
+ def log_usage(self, response):
76
+ """Log the total tokens from the Anthropic API response."""
77
+ usage_data = response.usage
78
+ if usage_data:
79
+ total_tokens = usage_data.input_tokens + usage_data.output_tokens
80
+ logger.info(f"{total_tokens}")
81
+
82
+ def basic_request(self, prompt: str, **kwargs):
83
+ raw_kwargs = kwargs
84
+
85
+ kwargs = {**self.kwargs, **kwargs}
86
+ # caching mechanism requires hashable kwargs
87
+ kwargs["messages"] = [{"role": "user", "content": prompt}]
88
+ kwargs.pop("n")
89
+ print(kwargs)
90
+ response = self.client.messages.create(**kwargs)
91
+
92
+ history = {
93
+ "prompt": prompt,
94
+ "response": response,
95
+ "kwargs": kwargs,
96
+ "raw_kwargs": raw_kwargs,
97
+ }
98
+ self.history.append(history)
99
+
100
+ return response
101
+
102
+ @backoff.on_exception(
103
+ backoff.expo,
104
+ (anthropic_rate_limit),
105
+ max_time=1000,
106
+ max_tries=8,
107
+ on_backoff=backoff_hdlr,
108
+ giveup=giveup_hdlr,
109
+ )
110
+ # def request(self, prompt: str, **kwargs):
111
+ # """Handles retrieval of completions from Anthropic whilst handling API errors"""
112
+ # return self.basic_request(prompt, **kwargs)
113
+
114
+ def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
115
+ """Retrieves completions from Anthropic.
116
+
117
+ Args:
118
+ prompt (str): prompt to send to Anthropic
119
+ only_completed (bool, optional): return only completed responses and ignores completion due to length. Defaults to True.
120
+ return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.
121
+
122
+ Returns:
123
+ list[str]: list of completion choices
124
+ """
125
+
126
+ assert only_completed, "for now"
127
+ assert return_sorted is False, "for now"
128
+
129
+ # per eg here: https://docs.anthropic.com/claude/reference/messages-examples
130
+ # max tokens can be used as a proxy to return smaller responses
131
+ # so this cannot be a proper indicator for incomplete response unless it isnt the user-intent.
132
+ # if only_completed and response.stop_reason != "end_turn":
133
+ # choices = []
134
+
135
+ n = kwargs.pop("n", 1)
136
+ completions = []
137
+ for i in range(n):
138
+ response = self.basic_request(prompt, **kwargs)
139
+ # TODO: Log llm usage instead of hardcoded openai usage
140
+ # if dsp.settings.log_openai_usage:
141
+ # self.log_usage(response)
142
+ if only_completed and response.stop_reason == "max_tokens":
143
+ continue
144
+ completions = [c.text for c in response.content]
145
+ return completions
app.py CHANGED
@@ -2,12 +2,12 @@ import streamlit as st
2
  from writer import write_article, incorporate_feedback, _template, evaluate_post
3
  import hmac
4
  import dspy
5
- from dsp.modules import Claude
6
 
7
  import phoenix as px
8
 
9
- my_traces = px.Client().get_trace_dataset().save()
10
- px.launch_app(trace=px.TraceDataset.load(my_traces))
11
 
12
 
13
  def check_password():
 
2
  from writer import write_article, incorporate_feedback, _template, evaluate_post
3
  import hmac
4
  import dspy
5
+ from Claude import Claude
6
 
7
  import phoenix as px
8
 
9
+
10
+ px.launch_app()
11
 
12
 
13
  def check_password():
writer.py CHANGED
@@ -10,7 +10,7 @@ from opentelemetry.sdk import trace as trace_sdk
10
  from opentelemetry.sdk.resources import Resource
11
  from opentelemetry.sdk.trace.export import SimpleSpanProcessor
12
 
13
- endpoint = "http://127.0.0.1:6006/v1/traces"
14
  resource = Resource(attributes={})
15
  tracer_provider = trace_sdk.TracerProvider(resource=resource)
16
  span_otlp_exporter = OTLPSpanExporter(endpoint=endpoint)
 
10
  from opentelemetry.sdk.resources import Resource
11
  from opentelemetry.sdk.trace.export import SimpleSpanProcessor
12
 
13
+ endpoint = "http://127.0.0.1:7860/v1/traces"
14
  resource = Resource(attributes={})
15
  tracer_provider = trace_sdk.TracerProvider(resource=resource)
16
  span_otlp_exporter = OTLPSpanExporter(endpoint=endpoint)