File size: 6,125 Bytes
f256e6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import netrc
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import streamlit as st


def initialize_env_variables(
    prod_mode: bool = False,
    additional_env_vars: Optional[Union[List[str], Dict[str, Any]]] = None,
) -> None:
    """Ensure credential variables are present in ``st.session_state``.

    Always initialises ``SAMBANOVA_API_KEY``, ``SAMBASTUDIO_URL`` and
    ``SAMBASTUDIO_API_KEY`` (default ``''``), then every name in
    ``additional_env_vars``.

    Args:
        prod_mode: When False, a value present in ``os.environ`` takes priority
            and is written into session state; otherwise the existing session
            value (or the default) is kept. When True, the environment is never
            read and only keys missing from session state are set to their
            default.
        additional_env_vars: Extra variable names. A list gives each name a
            default of ``''``; a dict maps each name to its default value.

    Raises:
        ValueError: If ``additional_env_vars`` is neither a list nor a dict.
            Validation happens before session state is modified.
    """
    if additional_env_vars is None:
        additional_env_vars = []

    # Normalise to (name, default) pairs up front so an invalid argument is
    # rejected before any session state is touched.
    if isinstance(additional_env_vars, list):
        extra_vars = [(name, '') for name in additional_env_vars]
    elif isinstance(additional_env_vars, dict):
        extra_vars = list(additional_env_vars.items())
    else:
        raise ValueError('additional_env_vars must be a List or Dict')

    base_vars = [
        ('SAMBANOVA_API_KEY', ''),
        ('SAMBASTUDIO_URL', ''),
        ('SAMBASTUDIO_API_KEY', ''),
    ]

    # Base variables are processed first, so an additional var that repeats a
    # base name sees the value already stored by the base pass.
    for name, default in base_vars + extra_vars:
        if not prod_mode:
            # Non-prod: environment variables take priority over session state.
            st.session_state[name] = os.environ.get(name, st.session_state.get(name, default))
        elif name not in st.session_state:
            # Prod: never read the environment; only fill in missing keys.
            st.session_state[name] = default


def set_env_variables(
    api_key: str,
    additional_vars: Optional[Dict[str, Any]] = None,
    prod_mode: bool = False,
) -> None:
    """Store credentials in ``st.session_state`` and, outside prod mode, in ``os.environ``.

    Args:
        api_key: Stored under the ``SAMBANOVA_API_KEY`` name.
            NOTE(review): in SambaStudio mode ``env_input_fields`` returns the
            SambaStudio API key as ``api_key``; it is still stored under
            ``SAMBANOVA_API_KEY`` here — confirm callers expect that.
        additional_vars: Extra name/value pairs stored alongside the key. A
            pair named ``SAMBANOVA_API_KEY`` overrides ``api_key``.
        prod_mode: When True, only session state is updated and the process
            environment is left untouched.
    """
    credentials: Dict[str, Any] = {'SAMBANOVA_API_KEY': api_key}
    if additional_vars:
        credentials.update(additional_vars)

    for name, value in credentials.items():
        st.session_state[name] = value

    if not prod_mode:
        # In non-prod mode, also set environment variables.
        for name, value in credentials.items():
            if value is None:
                # str(None) would export the literal string 'None'; skip instead.
                continue
            # os.environ only accepts str values and raises TypeError otherwise.
            os.environ[name] = value if isinstance(value, str) else str(value)


def env_input_fields(
    additional_env_vars: Optional[Union[List[str], Dict[str, Any]]] = None,
    mode: str = 'SambaNova Cloud',
) -> Tuple[str, Dict[str, Any]]:
    """Render Streamlit password inputs for the credentials required by ``mode``.

    Args:
        additional_env_vars: Extra variables to render inputs for. With a list,
            each input is pre-filled from session state (or ``''``). With a
            dict, each input is pre-filled from session state, falling back to
            the dict value. In ``'SambaStudio'`` mode a dict key
            ``'SAMBANOVA_URL'`` is skipped.
        mode: Either ``'SambaNova Cloud'`` or ``'SambaStudio'``.

    Returns:
        ``(api_key, additional_vars)``. ``additional_vars`` maps each variable
        name to the entered value and, in SambaStudio mode, also contains
        ``'SAMBASTUDIO_URL'``.

    Raises:
        ValueError: If ``additional_env_vars`` is neither a list nor a dict, or
            ``mode`` is not supported.
    """
    if additional_env_vars is None:
        additional_env_vars = []
    # Reject unsupported types, consistent with initialize_env_variables,
    # instead of silently rendering no extra inputs.
    if not isinstance(additional_env_vars, (list, dict)):
        raise ValueError('additional_env_vars must be a List or Dict')

    additional_vars: Dict[str, Any]
    if mode == 'SambaNova Cloud':
        api_key = st.text_input(
            'SAMBANOVA CLOUD API KEY', value=st.session_state.get('SAMBANOVA_API_KEY', ''), type='password'
        )
        additional_vars = {}
    elif mode == 'SambaStudio':
        url = st.text_input('SAMBASTUDIO URL', value=st.session_state.get('SAMBASTUDIO_URL', ''), type='password')
        api_key = st.text_input(
            'SAMBASTUDIO API KEY', value=st.session_state.get('SAMBASTUDIO_API_KEY', ''), type='password'
        )
        additional_vars = {'SAMBASTUDIO_URL': url}
    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError('Setup mode not supported.')

    if isinstance(additional_env_vars, list):
        for var in additional_env_vars:
            additional_vars[var] = st.text_input(f'{var}', value=st.session_state.get(var, ''), type='password')
    else:
        for key, val in additional_env_vars.items():
            if mode == 'SambaStudio' and key == 'SAMBANOVA_URL':
                continue
            additional_vars[key] = st.text_input(f'{key}', value=st.session_state.get(key, val), type='password')

    return api_key, additional_vars


def are_credentials_set(additional_env_vars: Optional[Union[List[str], Dict[str, Any]]] = None) -> bool:
    """Return True when every required credential holds a truthy value in session state.

    Args:
        additional_env_vars: Extra variable names that must also be set. Either
            a list of names or a dict whose keys are the names; dict values are
            ignored.

    Returns:
        True if ``SAMBANOVA_API_KEY`` and every additional variable are truthy
        in ``st.session_state``. Missing keys count as unset.

    Raises:
        ValueError: If ``additional_env_vars`` is neither a list nor a dict.
    """
    if additional_env_vars is None:
        additional_env_vars = []
    # Without this check an unsupported type previously hit UnboundLocalError.
    if not isinstance(additional_env_vars, (list, dict)):
        raise ValueError('additional_env_vars must be a List or Dict')

    # .get() rather than attribute access, so an uninitialised session reports
    # "not set" instead of raising.
    if not st.session_state.get('SAMBANOVA_API_KEY', ''):
        return False

    # Iterating a dict yields its keys, so lists and dicts are handled alike.
    return all(bool(st.session_state.get(name, '')) for name in additional_env_vars)


def save_credentials(api_key: str, additional_vars: Optional[Dict[str, Any]] = None, prod_mode: bool = False) -> str:
    """Persist the given credentials and return a confirmation message.

    All arguments are forwarded unchanged to ``set_env_variables``.

    Returns:
        A fixed success message suitable for display in the UI.
    """
    set_env_variables(api_key=api_key, additional_vars=additional_vars, prod_mode=prod_mode)
    confirmation = 'Credentials saved successfully!'
    return confirmation


import netrc
import os
from typing import Optional


def get_wandb_key() -> Optional[str]:
    """
    Retrieve the Weights & Biases API key from the environment or ~/.netrc.

    WANDB_API_KEY is always removed from the environment (to prevent conflicts
    with weave), even when its value is the one returned.

    The environment variable takes precedence. ~/.netrc (entry for machine
    ``api.wandb.ai``) is consulted only when the variable is unset or empty.
    A missing, unreadable, or malformed ~/.netrc is treated as "no key".

    Returns:
        The API key if found, otherwise None.
    """
    # Pop unconditionally so the variable is gone regardless of which source wins.
    env_wandb_api_key = os.environ.pop('WANDB_API_KEY', None)
    if env_wandb_api_key:
        # No need to touch ~/.netrc when the environment already supplied a key.
        return env_wandb_api_key

    netrc_path = os.path.expanduser('~/.netrc')
    try:
        auth = netrc.netrc(netrc_path).authenticators('api.wandb.ai')
    except (OSError, netrc.NetrcParseError):
        # Best-effort fallback: OSError covers FileNotFoundError as well as
        # PermissionError / IsADirectoryError, which previously escaped.
        return None

    # authenticators() yields (login, account, password); the password is the API key.
    if auth and len(auth) == 3 and auth[2]:
        return auth[2]

    # If neither source provides a key, return None.
    return None