Spaces:
Sleeping
Sleeping
File size: 6,125 Bytes
f256e6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import netrc
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import streamlit as st
def initialize_env_variables(
    prod_mode: bool = False,
    additional_env_vars: Optional[Union[List[str], Dict[str, Any]]] = None,
) -> None:
    """Initialise credential keys in ``st.session_state``.

    The base keys ``SAMBANOVA_API_KEY``, ``SAMBASTUDIO_URL`` and
    ``SAMBASTUDIO_API_KEY`` are always processed first, followed by any keys
    named in ``additional_env_vars``, in their given order.

    Args:
        prod_mode: When False, each key is (re)set on every call to the value
            of the matching environment variable if present, else to the value
            already in session state, else to the key's fallback. When True the
            process environment is never read; a key is set to its fallback
            only if it is not already in session state.
        additional_env_vars: Extra keys to initialise. A list uses ``''`` as
            the fallback for every key; a dict maps each key to its own
            fallback value. ``None`` means no extra keys.

    Raises:
        ValueError: If ``additional_env_vars`` is neither a list nor a dict.
            The check happens before session state is touched, so an invalid
            argument leaves session state unchanged.
    """
    if additional_env_vars is None:
        additional_env_vars = []

    # Validate up front so a bad argument cannot leave state half-written.
    if isinstance(additional_env_vars, list):
        extra_entries = [(var, '') for var in additional_env_vars]
    elif isinstance(additional_env_vars, dict):
        extra_entries = list(additional_env_vars.items())
    else:
        raise ValueError('additional_env_vars must be a List or Dict')

    base_keys = ('SAMBANOVA_API_KEY', 'SAMBASTUDIO_URL', 'SAMBASTUDIO_API_KEY')
    # Base keys go first: if an extra entry repeats a base key, the base key's
    # '' fallback has already been applied, exactly as in the original order.
    entries = [(key, '') for key in base_keys] + extra_entries

    for key, fallback in entries:
        if prod_mode:
            # Prod mode: ignore the process environment, only fill gaps.
            if key not in st.session_state:
                st.session_state[key] = fallback
        else:
            # Non-prod mode: environment > existing session value > fallback.
            st.session_state[key] = os.environ.get(key, st.session_state.get(key, fallback))
def set_env_variables(
    api_key: str,
    additional_vars: Optional[Dict[str, Any]] = None,
    prod_mode: bool = False,
) -> None:
    """Store credentials in ``st.session_state`` and, outside prod mode, in ``os.environ``.

    Args:
        api_key: Value stored under ``SAMBANOVA_API_KEY``.
        additional_vars: Extra ``{key: value}`` pairs stored alongside it.
        prod_mode: When True, only session state is written. When False, the
            same keys are also exported to the process environment.

    Session state receives the values unchanged. ``os.environ`` only accepts
    strings, so environment values are converted with ``str()``; ``None``
    becomes ``''``, the "unset" value used throughout this module.
    """
    extra = additional_vars or {}

    st.session_state.SAMBANOVA_API_KEY = api_key
    for key, value in extra.items():
        st.session_state[key] = value

    if prod_mode:
        return

    # In non-prod mode, also set environment variables. Coerce to str so that
    # non-string values (e.g. dict defaults) do not raise TypeError after
    # session state has already been updated.
    os.environ['SAMBANOVA_API_KEY'] = '' if api_key is None else str(api_key)
    for key, value in extra.items():
        os.environ[key] = '' if value is None else str(value)
def env_input_fields(
    additional_env_vars: Optional[Union[List[str], Dict[str, Any]]] = None,
    mode: str = 'SambaNova Cloud',
) -> Tuple[str, Dict[str, Any]]:
    """Render password-type text inputs for the credentials required by ``mode``.

    Each input is prefilled from ``st.session_state``.

    Args:
        additional_env_vars: Extra keys to render inputs for. A list prefills
            missing keys with ``''``; a dict prefills them with the mapped
            value. Other types are ignored. ``None`` means no extra keys.
        mode: ``'SambaNova Cloud'`` or ``'SambaStudio'``.

    Returns:
        ``(api_key, additional_vars)``, where ``additional_vars`` maps each
        additionally collected key to the entered value. In SambaStudio mode it
        includes ``SAMBASTUDIO_URL`` and ``SAMBASTUDIO_API_KEY``.

    Raises:
        ValueError: If ``mode`` is not supported.
    """
    if additional_env_vars is None:
        additional_env_vars = []

    additional_vars: Dict[str, Any] = {}
    if mode == 'SambaNova Cloud':
        api_key = st.text_input(
            'SAMBANOVA CLOUD API KEY', value=st.session_state.get('SAMBANOVA_API_KEY', ''), type='password'
        )
    elif mode == 'SambaStudio':
        additional_vars['SAMBASTUDIO_URL'] = st.text_input(
            'SAMBASTUDIO URL', value=st.session_state.get('SAMBASTUDIO_URL', ''), type='password'
        )
        api_key = st.text_input(
            'SAMBASTUDIO API KEY', value=st.session_state.get('SAMBASTUDIO_API_KEY', ''), type='password'
        )
        # The input is prefilled from SAMBASTUDIO_API_KEY, so return it under that
        # key too; otherwise a saved SambaStudio key never round-trips.
        additional_vars['SAMBASTUDIO_API_KEY'] = api_key
    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError('Setup mode not supported.')

    if isinstance(additional_env_vars, list):
        candidates = [(var, '') for var in additional_env_vars]
    elif isinstance(additional_env_vars, dict):
        candidates = [
            (key, val)
            for key, val in additional_env_vars.items()
            # Kept from the original implementation for compatibility.
            if not (mode == 'SambaStudio' and key == 'SAMBANOVA_URL')
        ]
    else:
        candidates = []  # non-list/dict values were silently ignored before, too

    for key, default in candidates:
        # Render each key at most once: the mode-specific inputs above (and any
        # duplicate entries) must not produce a second input whose value would
        # overwrite the first.
        if key in additional_vars:
            continue
        additional_vars[key] = st.text_input(str(key), value=st.session_state.get(key, default), type='password')
    return api_key, additional_vars
def are_credentials_set(additional_env_vars: Optional[Union[List[str], Dict[str, Any]]] = None) -> bool:
    """Return True if ``SAMBANOVA_API_KEY`` and every extra key are truthy in session state.

    Args:
        additional_env_vars: Extra keys that must also be set. For a dict, only
            its keys are checked. ``None`` means no extra keys.

    Returns:
        True only if every required key is present in ``st.session_state`` with
        a truthy value. Missing keys count as unset rather than raising.

    Raises:
        ValueError: If ``additional_env_vars`` is neither a list nor a dict.
            This replaces the previous ``UnboundLocalError`` and matches
            ``initialize_env_variables``.
    """
    if additional_env_vars is None:
        additional_env_vars = []
    if not isinstance(additional_env_vars, (list, dict)):
        raise ValueError('additional_env_vars must be a List or Dict')

    # Iterating a dict yields its keys, so lists and dicts are handled alike.
    required_keys = ['SAMBANOVA_API_KEY', *additional_env_vars]
    return all(st.session_state.get(key, '') for key in required_keys)
def save_credentials(api_key: str, additional_vars: Optional[Dict[str, Any]] = None, prod_mode: bool = False) -> str:
    """Persist credentials via ``set_env_variables`` and return a confirmation message.

    Args:
        api_key: Forwarded unchanged to ``set_env_variables``.
        additional_vars: Forwarded unchanged to ``set_env_variables``.
        prod_mode: Forwarded unchanged to ``set_env_variables``.

    Returns:
        A fixed success string for display in the UI.
    """
    set_env_variables(api_key, additional_vars=additional_vars, prod_mode=prod_mode)
    confirmation = 'Credentials saved successfully!'
    return confirmation
import netrc
import os
from typing import Optional
def get_wandb_key(netrc_path: Optional[str] = None, host: str = 'api.wandb.ai') -> Optional[str]:
    """
    Retrieve the Weights & Biases API key from the environment or a netrc file,
    and remove WANDB_API_KEY from the environment to prevent conflicts with weave.

    The environment variable takes precedence. The netrc file is only read when
    the variable is unset or empty.

    Args:
        netrc_path: netrc file to consult; defaults to ``~/.netrc``.
        host: netrc ``machine`` entry that holds the key; defaults to ``api.wandb.ai``.

    Returns:
        The API key if found, otherwise None. Empty values count as not found.
        A missing, unreadable or malformed netrc file is treated as "no key"
        rather than raising.
    """
    # Always pop: removing the variable is required even when its value is
    # the one returned.
    env_wandb_api_key = os.environ.pop('WANDB_API_KEY', None)
    if env_wandb_api_key:
        return env_wandb_api_key

    if netrc_path is None:
        netrc_path = os.path.expanduser('~/.netrc')
    try:
        auth = netrc.netrc(netrc_path).authenticators(host)
    except (OSError, netrc.NetrcParseError):
        # Best effort: OSError covers missing files as well as permission
        # errors and paths that are directories.
        return None

    # authenticators() returns (login, account, password); the password is the API key.
    if auth and len(auth) == 3 and auth[2]:
        return auth[2]
    return None