import streamlit as st

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_aws import ChatBedrock
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI

from typing import Optional
from config import MODEL_OPTIONS
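
# Assumption: MODEL_OPTIONS (defined in config.py) maps each provider name to a
# default model id. Illustrative shape only; the real values live in config.py:
#
#   MODEL_OPTIONS = {
#       "OpenAI": "gpt-4o-mini",
#       "Anthropic": "claude-3-5-sonnet-20240620",
#       "Bedrock": "anthropic.claude-3-sonnet-20240229-v1:0",
#       "Google": "gemini-1.5-flash",
#       "Groq": "llama-3.1-8b-instant",
#   }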


def create_llm_model(llm_provider: str, **kwargs):
    """Create a language model instance for the selected provider.

    Connection/auth settings come from st.session_state['params'];
    generation settings (temperature, max_tokens, streaming) come via kwargs.
    """
    params = st.session_state.get('params', {})

    if llm_provider == "OpenAI":
        return ChatOpenAI(
            base_url=params.get("base_url"),
            openai_api_key=params.get("api_key"),
            model=MODEL_OPTIONS['OpenAI'],
            temperature=kwargs.get('temperature', 0.7),
            max_tokens=kwargs.get('max_tokens'),
            streaming=kwargs.get('streaming', False),
        )
    elif llm_provider == "Antropic":
        return ChatAnthropic(
            base_url=params.get("base_url"),
            anthropic_api_key=params.get("api_key"),
            model=MODEL_OPTIONS['Antropic'],
            temperature=kwargs.get('temperature', 0.7),
        )
    elif llm_provider == "Bedrock":
        import boto3
        # Initialize Bedrock client
        _bedrock = boto3.client(
            'bedrock-runtime',
            region_name=params.get("region_name"),
            aws_access_key_id=params.get("aws_access_key"),
            aws_secret_access_key=params.get("aws_secret_key"),
        )
        return ChatBedrock(
            client=_bedrock,
            model_id=MODEL_OPTIONS['Bedrock'],
            **kwargs
        )

    elif llm_provider == "Google":
        return ChatGoogleGenerativeAI(
            google_api_key=params.get("api_key"),
            model=MODEL_OPTIONS['Google'],
            temperature=kwargs.get('temperature', 0.7),
            max_tokens=kwargs.get('max_tokens', 4096),
            max_retries=2,
        )
    elif llm_provider == "Groq":
        return ChatGroq(
            api_key=params.get("api_key"),  # groq_api_key expected here
            model=MODEL_OPTIONS['Groq'],
            temperature=kwargs.get("temperature", 0.7),
            streaming=kwargs.get("streaming", False)
        )
    else:
        raise ValueError(f"Unsupported LLM provider: {llm_provider}")
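
# Example usage (illustrative; assumes st.session_state['params'] already holds
# the provider credentials and endpoint settings):
#
#   llm = create_llm_model("OpenAI", temperature=0.2)
#   print(llm.invoke("Hello").content)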
    

def get_response(prompt: str, llm_provider: str):
    """Get a response from the LLM using the standard LangChain interface."""
    try:
        # Create the LLM instance dynamically
        llm = create_llm_model(llm_provider)

        # Wrap prompt in a HumanMessage
        message = HumanMessage(content=prompt)

        # Invoke model and return the output content
        response = llm.invoke([message])
        return response.content

    except Exception as e:
        return f"Error during LLM invocation: {str(e)}"

def get_response_stream(
    prompt: str,
    llm_provider: str,
    system: Optional[str] = None,
    temperature: float = 1.0,
    max_tokens: int = 4096,
    **kwargs
):
    """
    Get a streaming response from the selected LLM provider.

    Connection/auth settings are read from st.session_state['params'] inside
    create_llm_model; generation settings are forwarded through kwargs.
    """
    try:
        # Add streaming and generation params to kwargs
        kwargs.update({
            "temperature": temperature,
            "max_tokens": max_tokens,
            "streaming": True
        })

        # Create the LLM with streaming enabled
        llm = create_llm_model(llm_provider, **kwargs)

        # Compose messages
        messages = []
        if system:
            messages.append(SystemMessage(content=system))
        messages.append(HumanMessage(content=prompt))

        # Stream the response
        stream_response = llm.stream(messages)
        return stream_response
    except Exception as e:
        st.error(f"[Error during streaming: {str(e)}]")
        st.stop()
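
# Example usage (illustrative; renders the stream incrementally in Streamlit):
#
#   placeholder = st.empty()
#   text = ""
#   for chunk in get_response_stream("Tell me a joke", "Groq", system="Be brief"):
#       text += chunk.content or ""
#       placeholder.markdown(text)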


def test_llm_connection(llm_provider: str, test_params: Optional[dict] = None):
    """
    Test the connection to the specified LLM provider.
    Returns a tuple of (success: bool, message: str).
    """
    # Temporarily swap in test_params if provided; restore in the finally block
    # so the original params survive even if the test raises.
    original_params = st.session_state.get('params', {})
    if test_params:
        st.session_state['params'] = test_params

    try:
        # Create LLM instance
        llm = create_llm_model(llm_provider)

        # Test with a simple message
        test_message = HumanMessage(content="Hello, this is a connection test. Please respond with 'OK'.")
        response = llm.invoke([test_message])

        return True, f"✅ Connection successful! Model response: {response.content[:100]}..."

    except Exception as e:
        return False, f"❌ Connection failed: {str(e)}"

    finally:
        # Restore the original params if we swapped them out
        if test_params:
            st.session_state['params'] = original_params
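
# Example usage (illustrative; the dict keys mirror what create_llm_model reads
# from st.session_state['params']):
#
#   ok, msg = test_llm_connection("OpenAI", {"api_key": "sk-...", "base_url": None})
#   (st.success if ok else st.error)(msg)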