LinkedIn-Post-Generator / utils /credential_manager.py
samagra44
initial commit
f154798
"""
Credential management utilities for Azure OpenAI
"""
import os
import streamlit as st
from typing import Dict, Optional, Tuple
class CredentialManager:
    """Manages Azure OpenAI credentials for the application.

    Credentials are passed around as a plain dict with the keys
    ``api_key``, ``endpoint``, ``deployment`` and ``api_version``.  They can
    be read from process environment variables or from Streamlit session
    state, written back to either, and sanity-checked with
    ``validate_credentials``.
    """

    # API version used whenever none has been configured.
    DEFAULT_API_VERSION = '2024-02-01'

    # Credential field -> environment variable that stores it.
    _ENV_VARS = {
        'api_key': 'AZURE_OPENAI_API_KEY',
        'endpoint': 'AZURE_OPENAI_ENDPOINT',
        'deployment': 'AZURE_DEPLOYMENT',
        'api_version': 'OPENAI_API_VERSION',
    }

    # Credential field -> Streamlit session_state key that stores it.
    _SESSION_KEYS = {
        'api_key': 'azure_api_key',
        'endpoint': 'azure_endpoint',
        'deployment': 'azure_deployment',
        'api_version': 'azure_api_version',
    }

    # Accepted endpoint suffixes. The trailing slash is optional so that an
    # endpoint entered without it is not rejected.
    _ENDPOINT_SUFFIXES = ('.openai.azure.com/', '.openai.azure.com')

    @staticmethod
    def _default_for(field: str) -> str:
        """Return the fallback value used when *field* is not configured."""
        return CredentialManager.DEFAULT_API_VERSION if field == 'api_version' else ''

    @staticmethod
    def get_credentials_from_env() -> Dict[str, str]:
        """Get credentials from environment variables.

        Unset variables yield ``''``, except ``api_version`` which falls back
        to ``DEFAULT_API_VERSION``.
        """
        return {
            field: os.getenv(var, CredentialManager._default_for(field))
            for field, var in CredentialManager._ENV_VARS.items()
        }

    @staticmethod
    def get_credentials_from_session() -> Dict[str, str]:
        """Get credentials from Streamlit session state.

        Missing keys yield ``''``, except ``api_version`` which falls back to
        ``DEFAULT_API_VERSION``.
        """
        return {
            field: st.session_state.get(key, CredentialManager._default_for(field))
            for field, key in CredentialManager._SESSION_KEYS.items()
        }

    @staticmethod
    def save_credentials_to_session(api_key: str, endpoint: str, deployment: str,
                                    api_version: str = DEFAULT_API_VERSION):
        """Save credentials to Streamlit session state.

        Args:
            api_key: Azure OpenAI API key.
            endpoint: Azure OpenAI endpoint URL.
            deployment: Deployment name.
            api_version: API version; defaults to ``DEFAULT_API_VERSION``.
        """
        values = {
            'api_key': api_key,
            'endpoint': endpoint,
            'deployment': deployment,
            'api_version': api_version,
        }
        for field, key in CredentialManager._SESSION_KEYS.items():
            st.session_state[key] = values[field]

    @staticmethod
    def set_environment_variables(credentials: Dict[str, str]):
        """Export credentials to the process environment.

        Args:
            credentials: Dict with ``api_key``, ``endpoint``, ``deployment``
                and ``api_version`` string values.

        Raises:
            KeyError: If one of the four fields is missing (variables for
                fields preceding it have already been set).
        """
        for field, var in CredentialManager._ENV_VARS.items():
            os.environ[var] = credentials[field]

    @staticmethod
    def validate_credentials(credentials: Dict[str, str]) -> Tuple[bool, str]:
        """Validate Azure OpenAI credentials.

        Missing or ``None`` fields are treated as empty and reported as
        required instead of raising.

        Args:
            credentials: Credential dict as returned by the ``get_*`` methods.

        Returns:
            ``(is_valid, message)`` where *message* explains the first
            problem found, or confirms validity.
        """
        api_key = credentials.get('api_key') or ''
        endpoint = credentials.get('endpoint') or ''
        deployment = credentials.get('deployment') or ''

        if not api_key:
            return False, "API Key is required"
        if not endpoint:
            return False, "Endpoint is required"
        if not deployment:
            return False, "Deployment name is required"

        # Basic format validation for Azure OpenAI
        if len(api_key) < 10:
            return False, "API Key seems too short"
        if not endpoint.startswith('https://'):
            return False, "Endpoint should be a valid HTTPS URL"
        if not endpoint.endswith(CredentialManager._ENDPOINT_SUFFIXES):
            return False, "Endpoint should end with '.openai.azure.com/'"
        return True, "Credentials are valid"

    @staticmethod
    def get_current_credentials(use_env: bool = True) -> Dict[str, str]:
        """Get current credentials from the environment or the session.

        Args:
            use_env: Read environment variables when true, otherwise
                Streamlit session state.
        """
        if use_env:
            return CredentialManager.get_credentials_from_env()
        return CredentialManager.get_credentials_from_session()

    @staticmethod
    def clear_session_credentials():
        """Remove any stored credentials from Streamlit session state."""
        for key in CredentialManager._SESSION_KEYS.values():
            if key in st.session_state:
                del st.session_state[key]
def format_endpoint_url(endpoint: str) -> str:
    """Format endpoint URL to ensure it ends with a slash.

    Surrounding whitespace (e.g. from a pasted value) is removed first.
    A blank value is returned as ``''`` rather than ``'/'`` so that it still
    reads as missing (falsy) to callers.

    Args:
        endpoint: Endpoint URL as entered by the user.

    Returns:
        The trimmed URL ending in ``/``, or ``''`` if the input was blank.
    """
    endpoint = endpoint.strip()
    if endpoint and not endpoint.endswith('/'):
        endpoint += '/'
    return endpoint
def mask_api_key(api_key: Optional[str]) -> str:
    """Mask API key for display purposes.

    Keys of 16 or more characters keep their first and last four characters
    visible, with everything in between replaced by ``*``.  Shorter keys are
    masked completely, because showing eight characters of them would reveal
    more than half of the secret.  ``None`` and ``''`` yield ``''``.

    Args:
        api_key: The secret to mask.

    Returns:
        A string of the same length as *api_key* suitable for display.
    """
    if not api_key:
        return ''
    visible = 4
    # Reveal the ends only while at least half of the key stays hidden.
    if len(api_key) >= 4 * visible:
        hidden = len(api_key) - 2 * visible
        return api_key[:visible] + '*' * hidden + api_key[-visible:]
    return '*' * len(api_key)