File size: 4,517 Bytes
5a3fcad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d15e904
5a3fcad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
LLM Model Utilities

This module provides helper functions for initializing LLM models
used by the NexDatawork agents.

Supported Models:
    - Azure OpenAI (GPT-4, GPT-3.5-turbo)
    - OpenAI API (GPT-4, GPT-3.5-turbo)
    
Environment Variables Required:
    For Azure:
        - AZURE_OPENAI_ENDPOINT
        - AZURE_OPENAI_API_KEY
    For OpenAI:
        - OPENAI_API_KEY
"""

import os
from typing import Any

from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


def create_azure_model(
    deployment_name: str = "gpt-4.1",
    api_version: str = "2025-01-01-preview",
    streaming: bool = True,
    temperature: float = 0.0
) -> Any:
    """
    Create an Azure OpenAI chat model instance.

    This function initializes an AzureChatOpenAI model with the specified
    configuration. The endpoint is read from AZURE_OPENAI_ENDPOINT and passed
    to the model explicitly; the credential is not passed here and is left
    for the underlying client to pick up from the environment. Both are
    checked for presence before the model is constructed.

    Args:
        deployment_name: The Azure deployment name (default: "gpt-4.1").
        api_version: Azure API version string (default: "2025-01-01-preview").
        streaming: Enable streaming responses (default: True). When enabled,
            a StreamingStdOutCallbackHandler is attached to the model.
        temperature: Model temperature for randomness (default: 0.0).

    Returns:
        AzureChatOpenAI: Configured Azure OpenAI model instance.

    Raises:
        ValueError: If AZURE_OPENAI_ENDPOINT is not set, or if no credential
            environment variable is set.
        ImportError: If langchain_openai is not installed.

    Example:
        >>> model = create_azure_model(deployment_name="gpt-4")
        >>> response = model.invoke("Hello!")
    """
    # Validate configuration before touching the optional dependency so a
    # misconfigured environment is reported immediately and uniformly.
    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    if not endpoint:
        raise ValueError("AZURE_OPENAI_ENDPOINT environment variable not set")

    # AZURE_OPENAI_API_KEY is the documented requirement for this module.
    # NOTE(review): AzureChatOpenAI is documented to also fall back to
    # OPENAI_API_KEY and AZURE_OPENAI_AD_TOKEN; they are accepted here so
    # token-based setups keep working -- confirm against the installed
    # langchain_openai version.
    credential_vars = (
        "AZURE_OPENAI_API_KEY",
        "OPENAI_API_KEY",
        "AZURE_OPENAI_AD_TOKEN",
    )
    if not any(os.getenv(name) for name in credential_vars):
        raise ValueError("AZURE_OPENAI_API_KEY environment variable not set")

    # Imported lazily so the module can be imported without langchain_openai.
    from langchain_openai import AzureChatOpenAI

    # Stream tokens to stdout only when streaming is requested.
    callbacks = [StreamingStdOutCallbackHandler()] if streaming else []

    return AzureChatOpenAI(
        openai_api_version=api_version,
        azure_deployment=deployment_name,
        azure_endpoint=endpoint,
        streaming=streaming,
        callbacks=callbacks,
        temperature=temperature,
    )


def create_openai_model(
    model_name: str = "gpt-4-turbo-preview",
    streaming: bool = True,
    temperature: float = 0.0
) -> Any:
    """
    Build a ChatOpenAI chat model for the public OpenAI API.

    The OPENAI_API_KEY environment variable must be set; it is only checked
    for presence here and is not passed to the constructor by this function.

    Args:
        model_name: The OpenAI model name (default: "gpt-4-turbo-preview").
        streaming: Enable streaming responses (default: True). When enabled,
            a StreamingStdOutCallbackHandler is attached to the model.
        temperature: Model temperature for randomness (default: 0.0).

    Returns:
        ChatOpenAI: Configured OpenAI model instance.

    Raises:
        ImportError: If langchain_openai is not installed.
        ValueError: If OPENAI_API_KEY is not set.

    Example:
        >>> model = create_openai_model(model_name="gpt-4")
        >>> response = model.invoke("Analyze this data...")
    """
    from langchain_openai import ChatOpenAI

    # Guard clause: refuse to build a model without a configured key.
    if not os.environ.get("OPENAI_API_KEY"):
        raise ValueError("OPENAI_API_KEY environment variable not set")

    # Attach the stdout streaming handler only when streaming is requested.
    handlers = []
    if streaming:
        handlers.append(StreamingStdOutCallbackHandler())

    model_options = {
        "model": model_name,
        "streaming": streaming,
        "callbacks": handlers,
        "temperature": temperature,
    }
    return ChatOpenAI(**model_options)


def get_model(provider: str = "azure", **kwargs) -> Any:
    """
    Factory function to create an LLM model based on provider.

    This is a convenience function that delegates to the appropriate
    model creation function based on the provider argument. The provider
    name is matched case-insensitively and surrounding whitespace is ignored.

    Args:
        provider: Either "azure" or "openai" (default: "azure").
        **kwargs: Additional arguments passed to the model creation function.

    Returns:
        The configured LLM model instance.

    Raises:
        ValueError: If provider is not a string naming a supported provider.
            Errors raised by the selected creation function propagate
            unchanged.

    Example:
        >>> model = get_model("openai", model_name="gpt-4")
        >>> model = get_model("azure", deployment_name="gpt-4.1")
    """
    # Normalize once. A non-string provider (e.g. None) has no normalized
    # form and falls through to the same ValueError as an unknown name,
    # instead of failing with an AttributeError on .lower().
    normalized = provider.strip().lower() if isinstance(provider, str) else None

    if normalized == "azure":
        return create_azure_model(**kwargs)
    if normalized == "openai":
        return create_openai_model(**kwargs)

    raise ValueError(f"Unknown provider: {provider}. Use 'azure' or 'openai'.")