File size: 4,093 Bytes
f154798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Credential management utilities for Azure OpenAI
"""

import os
import streamlit as st
from typing import Dict, Optional, Tuple

class CredentialManager:
    """Manages Azure OpenAI credentials for the application.

    Credentials are passed around as plain dicts with the keys ``api_key``,
    ``endpoint``, ``deployment`` and ``api_version``. They can be read from
    environment variables or Streamlit session state, written back to either,
    validated, and cleared from the session.
    """

    # API version used whenever none has been configured.
    DEFAULT_API_VERSION = '2024-02-01'

    # Streamlit session-state keys that hold the credentials.
    _SESSION_KEYS = (
        'azure_api_key',
        'azure_endpoint',
        'azure_deployment',
        'azure_api_version',
    )

    @staticmethod
    def get_credentials_from_env() -> Dict[str, str]:
        """Get credentials from environment variables.

        Unset variables yield ``''``, except the API version, which falls
        back to ``DEFAULT_API_VERSION``.
        """
        return {
            'api_key': os.getenv('AZURE_OPENAI_API_KEY', ''),
            'endpoint': os.getenv('AZURE_OPENAI_ENDPOINT', ''),
            'deployment': os.getenv('AZURE_DEPLOYMENT', ''),
            'api_version': os.getenv('OPENAI_API_VERSION', CredentialManager.DEFAULT_API_VERSION),
        }

    @staticmethod
    def get_credentials_from_session() -> Dict[str, str]:
        """Get credentials from Streamlit session state.

        Missing entries yield ``''``, except the API version, which falls
        back to ``DEFAULT_API_VERSION``.
        """
        return {
            'api_key': st.session_state.get('azure_api_key', ''),
            'endpoint': st.session_state.get('azure_endpoint', ''),
            'deployment': st.session_state.get('azure_deployment', ''),
            'api_version': st.session_state.get('azure_api_version', CredentialManager.DEFAULT_API_VERSION),
        }

    @staticmethod
    def save_credentials_to_session(api_key: str, endpoint: str, deployment: str,
                                    api_version: str = DEFAULT_API_VERSION):
        """Save credentials to Streamlit session state.

        The values are stored under the keys read back by
        ``get_credentials_from_session``.
        """
        st.session_state.azure_api_key = api_key
        st.session_state.azure_endpoint = endpoint
        st.session_state.azure_deployment = deployment
        st.session_state.azure_api_version = api_version

    @staticmethod
    def set_environment_variables(credentials: Dict[str, str]):
        """Export credentials to the environment variables read by ``get_credentials_from_env``.

        Args:
            credentials: Dict with ``api_key``, ``endpoint`` and ``deployment``
                (required). ``api_version`` is optional; when missing or empty,
                ``DEFAULT_API_VERSION`` is exported instead.

        Raises:
            KeyError: If ``api_key``, ``endpoint`` or ``deployment`` is missing.
        """
        os.environ['AZURE_OPENAI_API_KEY'] = credentials['api_key']
        os.environ['AZURE_OPENAI_ENDPOINT'] = credentials['endpoint']
        os.environ['AZURE_DEPLOYMENT'] = credentials['deployment']
        # Never export a blank API version; fall back to the default instead.
        os.environ['OPENAI_API_VERSION'] = (
            credentials.get('api_version') or CredentialManager.DEFAULT_API_VERSION
        )

    @staticmethod
    def validate_credentials(credentials: Dict[str, str]) -> Tuple[bool, str]:
        """Validate Azure OpenAI credentials.

        Checks run in order and only the first failure is reported. Missing
        keys, ``None`` and whitespace-only values are treated as absent.

        Args:
            credentials: Dict with ``api_key``, ``endpoint`` and ``deployment``.

        Returns:
            ``(True, "Credentials are valid")`` when every check passes,
            otherwise ``(False, <human-readable reason>)``.
        """
        # .get() so a partial dict reports "... is required" instead of raising KeyError.
        api_key = credentials.get('api_key') or ''
        endpoint = credentials.get('endpoint') or ''
        deployment = credentials.get('deployment') or ''

        # A blank text field (only spaces/newlines) counts as "not provided".
        if not api_key.strip():
            return False, "API Key is required"

        if not endpoint.strip():
            return False, "Endpoint is required"

        if not deployment.strip():
            return False, "Deployment name is required"

        # Basic format validation for Azure OpenAI
        if len(api_key) < 10:
            return False, "API Key seems too short"

        if not endpoint.startswith('https://'):
            return False, "Endpoint should be a valid HTTPS URL"

        if not endpoint.endswith('.openai.azure.com/'):
            return False, "Endpoint should end with '.openai.azure.com/'"

        return True, "Credentials are valid"

    @staticmethod
    def get_current_credentials(use_env: bool = True) -> Dict[str, str]:
        """Get current credentials from the environment (default) or the session."""
        if use_env:
            return CredentialManager.get_credentials_from_env()
        return CredentialManager.get_credentials_from_session()

    @staticmethod
    def clear_session_credentials():
        """Remove any stored credentials from Streamlit session state.

        Keys that are not present are skipped, so calling this repeatedly is safe.
        """
        for key in CredentialManager._SESSION_KEYS:
            if key in st.session_state:
                del st.session_state[key]

def format_endpoint_url(endpoint: str) -> str:
    """Normalize an endpoint URL so it ends with exactly one trailing slash.

    Surrounding whitespace (common when a URL is pasted) is removed first.
    An empty or whitespace-only endpoint is returned as ``''``. It is not
    turned into ``'/'``, so that validation still reports it as missing.

    Args:
        endpoint: The endpoint URL as entered by the user.

    Returns:
        The stripped URL with a trailing ``/`` appended if it lacked one,
        or ``''`` when nothing was entered.
    """
    endpoint = endpoint.strip()
    if endpoint and not endpoint.endswith('/'):
        endpoint += '/'
    return endpoint

def mask_api_key(api_key: str) -> str:
    """Mask an API key for display purposes.

    Keys of 16 or more characters keep their first and last 4 characters
    visible, so users can recognise which key is configured; everything in
    between becomes ``*``. Shorter keys are masked completely. Showing 8
    characters of, say, a 9-character key would reveal almost the whole
    secret.

    Args:
        api_key: The secret to mask. ``None`` is treated as an empty string.

    Returns:
        A string of the same length as ``api_key`` with hidden characters
        replaced by ``*``.
    """
    api_key = api_key or ''
    visible = 4  # characters revealed at each end
    # Only reveal the ends when at least half of the key remains hidden.
    if len(api_key) >= 4 * visible:
        hidden = len(api_key) - 2 * visible
        return api_key[:visible] + '*' * hidden + api_key[-visible:]
    return '*' * len(api_key)