Spaces:
Sleeping
Sleeping
| """ | |
| Credential management utilities for Azure OpenAI | |
| """ | |
| import os | |
| import streamlit as st | |
| from typing import Dict, Optional, Tuple | |
class CredentialManager:
    """Manage Azure OpenAI credentials for the application.

    All methods are stateless helpers declared as static methods, so they
    can be called either on the class (``CredentialManager.mask_api_key(k)``)
    or on an instance. Credentials are exchanged as a dict with the keys
    ``'api_key'``, ``'endpoint'``, ``'deployment'`` and ``'api_version'``.
    """

    # API version used whenever the caller/environment/session supplies none.
    DEFAULT_API_VERSION = '2024-02-01'

    # Session-state keys written by save_credentials_to_session().
    _SESSION_KEYS = (
        'azure_api_key',
        'azure_endpoint',
        'azure_deployment',
        'azure_api_version',
    )

    # Accepted endpoint endings: the trailing slash is optional because
    # format_endpoint_url() can add it afterwards.
    _ENDPOINT_SUFFIXES = ('.openai.azure.com', '.openai.azure.com/')

    # Keys shorter than this are rejected as implausible.
    _MIN_API_KEY_LENGTH = 10

    @staticmethod
    def get_credentials_from_env() -> Dict[str, str]:
        """Return the credentials currently held in environment variables.

        Unset variables yield an empty string, except the API version,
        which falls back to ``DEFAULT_API_VERSION``.
        """
        return {
            'api_key': os.getenv('AZURE_OPENAI_API_KEY', ''),
            'endpoint': os.getenv('AZURE_OPENAI_ENDPOINT', ''),
            'deployment': os.getenv('AZURE_DEPLOYMENT', ''),
            'api_version': os.getenv(
                'OPENAI_API_VERSION', CredentialManager.DEFAULT_API_VERSION
            ),
        }

    @staticmethod
    def get_credentials_from_session() -> Dict[str, str]:
        """Return the credentials stored in Streamlit session state.

        Missing entries yield an empty string, except the API version,
        which falls back to ``DEFAULT_API_VERSION``.
        """
        state = st.session_state
        return {
            'api_key': state.get('azure_api_key', ''),
            'endpoint': state.get('azure_endpoint', ''),
            'deployment': state.get('azure_deployment', ''),
            'api_version': state.get(
                'azure_api_version', CredentialManager.DEFAULT_API_VERSION
            ),
        }

    @staticmethod
    def save_credentials_to_session(
        api_key: str,
        endpoint: str,
        deployment: str,
        api_version: str = DEFAULT_API_VERSION,
    ) -> None:
        """Store the given credentials in Streamlit session state."""
        st.session_state.azure_api_key = api_key
        st.session_state.azure_endpoint = endpoint
        st.session_state.azure_deployment = deployment
        st.session_state.azure_api_version = api_version

    @staticmethod
    def set_environment_variables(credentials: Dict[str, str]) -> None:
        """Export *credentials* to the process environment.

        Raises:
            KeyError: if any of the four credential keys is missing.
        """
        os.environ['AZURE_OPENAI_API_KEY'] = credentials['api_key']
        os.environ['AZURE_OPENAI_ENDPOINT'] = credentials['endpoint']
        os.environ['AZURE_DEPLOYMENT'] = credentials['deployment']
        os.environ['OPENAI_API_VERSION'] = credentials['api_version']

    @staticmethod
    def validate_credentials(credentials: Dict[str, str]) -> Tuple[bool, str]:
        """Check that *credentials* look like usable Azure OpenAI settings.

        Only presence and basic format are checked; no network call is made.
        A missing key is treated the same as an empty value.

        Returns:
            ``(True, "Credentials are valid")`` on success, otherwise
            ``(False, reason)`` for the first failed check.
        """
        api_key = credentials.get('api_key') or ''
        endpoint = credentials.get('endpoint') or ''
        deployment = credentials.get('deployment') or ''

        # Presence checks first, in a fixed order, so the user is told about
        # missing fields before format problems.
        if not api_key:
            return False, "API Key is required"
        if not endpoint:
            return False, "Endpoint is required"
        if not deployment:
            return False, "Deployment name is required"

        # Basic format validation for Azure OpenAI
        if len(api_key) < CredentialManager._MIN_API_KEY_LENGTH:
            return False, "API Key seems too short"
        if not endpoint.startswith('https://'):
            return False, "Endpoint should be a valid HTTPS URL"
        # Accept the endpoint with or without its trailing slash; requiring
        # the slash rejected URLs that format_endpoint_url() would fix.
        if not endpoint.endswith(CredentialManager._ENDPOINT_SUFFIXES):
            return False, "Endpoint should end with '.openai.azure.com/'"

        return True, "Credentials are valid"

    @staticmethod
    def get_current_credentials(use_env: bool = True) -> Dict[str, str]:
        """Return credentials from the environment or from session state.

        Args:
            use_env: read environment variables when true, Streamlit
                session state otherwise.
        """
        if use_env:
            return CredentialManager.get_credentials_from_env()
        return CredentialManager.get_credentials_from_session()

    @staticmethod
    def clear_session_credentials() -> None:
        """Remove any stored credentials from Streamlit session state."""
        for key in CredentialManager._SESSION_KEYS:
            if key in st.session_state:
                del st.session_state[key]

    @staticmethod
    def format_endpoint_url(endpoint: str) -> str:
        """Return *endpoint* with a trailing slash appended if it has none."""
        if not endpoint.endswith('/'):
            endpoint += '/'
        return endpoint

    @staticmethod
    def mask_api_key(api_key: str) -> str:
        """Mask an API key for display.

        Keys longer than 8 characters keep their first and last 4 characters
        visible; shorter keys are replaced entirely by asterisks.
        """
        if len(api_key) > 8:
            return api_key[:4] + '*' * (len(api_key) - 8) + api_key[-4:]
        return '*' * len(api_key)