File size: 5,190 Bytes
7eb1167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
# Usage

Initialize your client as follows:

```python

import tracking_aws

bedrock_runtime = tracking_aws.new_default_client()
```

Your usage will be tracked in a local file called `usage_nrel_aws.json`.
"""
from __future__ import annotations
from math import nan
import itertools
import os
import io
import json
import pathlib
from typing import Union, Dict
import pathlib
from typing import Any
from contextlib import contextmanager

import boto3

Usage = Dict[str, Union[int, float]]  # TODO: could have used Counter class
default_usage_file = pathlib.Path("usage_nrel_aws.json")

CLAUDE_3_5_HAIKU = 'arn:aws:bedrock:us-west-2:991404956194:application-inference-profile/g47vfd2xvs5w'
CLAUDE_3_5_SONNET = 'arn:aws:bedrock:us-west-2:991404956194:application-inference-profile/56i8iq1vib3e'

# prices match https://aws.amazon.com/bedrock/pricing/ as of January 2025.
# but they don't consider discounts for caching or batching.
# Not all models are listed in this file, nor is the fine-tuning API.
pricing = {  # $ per 1,000 tokens
    CLAUDE_3_5_HAIKU: {'input': 0.0008, 'output': 0.004},
    CLAUDE_3_5_SONNET: {'input': 0.003, 'output': 0.015},
}

# Default models.  These variables can be imported from this module.
# Even if the system that is being evaluated uses a cheap default_model.
# one might want to evaluate it carefully using a more expensive default_eval_model.
default_model = CLAUDE_3_5_HAIKU
default_eval_model = CLAUDE_3_5_HAIKU


# A context manager that lets you temporarily change the default models
# during a block of code.  You can write things like
#     with use_model('arn:aws:bedrock:us-west-2:991404956194:application-inference-profile/g47vfd2xvs5w'):
#        ...
#
#     with use_model(eval_model='arn:aws:bedrock:us-west-2:991404956194:application-inference-profile/g47vfd2xvs5w'):
#        ...
@contextmanager
def use_model(model: str = default_model, eval_model: str = default_eval_model):
    global default_model, default_eval_model
    save_model, save_eval_model = default_model, default_eval_model
    default_model, default_eval_model = model, eval_model
    try:
        yield
    finally:
        default_model, default_eval_model = save_model, save_eval_model


def track_usage(client: boto3.client, path: pathlib.Path = default_usage_file) -> boto3.client:
    """
    This method modifies (and returns) `client` so that its API calls
    will log token counts to `path`. If the file does not exist it
    will be created after the first API call. If the file exists the new
    counts will be added to it.

    The `read_usage()` function gets a Usage object from the file, e.g.:
    {
        "cost": 0.0022136,
        "input_tokens": 16,
        "output_tokens": 272
    }

    >>> client = boto3.client('bedrock')
    >>> track_usage(client, "example_usage_file.json")
    >>> type(client)
    <class 'botocore.client.BaseClient'>

    """
    old_invoke_model = client.invoke_model

    def tracked_invoke_model(*args, **kwargs) -> Any:
        response = old_invoke_model(*args, **kwargs)
        old: Usage = read_usage(path)
        new, response_body = get_usage(response, model=kwargs.get('modelId', None))
        _write_usage(_merge_usage(old, new), path)
        return response_body

    client.invoke_model = tracked_invoke_model  # type:ignore
    return client


def get_usage(response, model=None) -> Usage:
    """Extract usage info from an AWS Bedrock response."""
    response_body = json.loads(response['body'].read().decode())
    usage: Usage = {'input_tokens': response_body['usage']['input_tokens'],
                    'output_tokens': response_body['usage']['output_tokens']}

    # add a cost field
    try:
        costs = pricing[model]  # model name passed in request (may be alias)
    except KeyError:
        raise ValueError(f"Don't know prices for model {model} or {response.model}")

    cost = (usage.get('input_tokens', 0) * costs['input']
            + usage.get('output_tokens', 0) * costs['output']) / 1_000
    usage['cost'] = cost
    return usage, response_body


def read_usage(path: pathlib.Path = default_usage_file) -> Usage:
    """Retrieve total usage logged in a file."""
    if os.path.exists(path):
        with open(path, "rt") as f:
            return json.load(f)
    else:
        return {}


def _write_usage(u: Usage, path: pathlib.Path):
    with open(path, "wt") as f:
        json.dump(u, f, indent=4)


def _merge_usage(u1: Usage, u2: Usage) -> Usage:
    return {k: u1.get(k, 0) + u2.get(k, 0) for k in itertools.chain(u1, u2)}


def new_default_client(default='boto3') -> boto3.client:
    """Set the `default_client` to a new tracked client, based on the current
    aws credentials. If your credentials change you should call this method again."""
    global default_client
    default_client = track_usage(
        boto3.client('bedrock-runtime', region_name='us-west-2'))  # create a client with default args, and modify it
    # so that it will store its usage in a local file
    return default_client


# new_default_client()       # set `default_client` right away when importing this module

if __name__ == "__main__":
    import doctest

    doctest.testmod()