import streamlit as st
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_aws import ChatBedrock
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from typing import Optional
from config import MODEL_OPTIONS
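
# MODEL_OPTIONS is defined in config.py (not shown here); it is assumed to map
# each provider name to a default model id, along the lines of:
#
#   MODEL_OPTIONS = {
#       "OpenAI": "gpt-4o",
#       "Antropic": "claude-3-5-sonnet-20241022",
#       "Bedrock": "anthropic.claude-3-sonnet-20240229-v1:0",
#       "Google": "gemini-1.5-pro",
#       "Groq": "llama-3.1-70b-versatile",
#   }
#
# The concrete model ids above are illustrative, not the actual config values.
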
def create_llm_model(llm_provider: str, **kwargs):
    """Create a language model based on the selected provider."""
    # Guard against a missing 'params' entry in session state
    params = st.session_state.get('params') or {}
    if llm_provider == "OpenAI":
        return ChatOpenAI(
            base_url=params.get("base_url"),
            openai_api_key=params.get("api_key"),
            model=MODEL_OPTIONS['OpenAI'],
            temperature=kwargs.get('temperature', 0.7),
            max_tokens=kwargs.get('max_tokens'),  # None falls back to the library default
        )
    elif llm_provider == "Antropic":  # NOTE: 'Antropic' (sic) matches the key in config.MODEL_OPTIONS
        return ChatAnthropic(
            base_url=params.get("base_url"),
            anthropic_api_key=params.get("api_key"),
            model=MODEL_OPTIONS['Antropic'],
            temperature=kwargs.get('temperature', 0.7),
            max_tokens=kwargs.get('max_tokens', 4096),  # forwarded so streaming calls are not capped at the library default
        )
    elif llm_provider == "Bedrock":
        import boto3
        # Initialize the Bedrock runtime client with explicit credentials
        _bedrock = boto3.client(
            'bedrock-runtime',
            region_name=params.get("region_name"),
            aws_access_key_id=params.get("aws_access_key"),
            aws_secret_access_key=params.get("aws_secret_key"),
        )
        return ChatBedrock(
            client=_bedrock,
            model_id=MODEL_OPTIONS['Bedrock'],
            **kwargs,
        )
    elif llm_provider == "Google":
        return ChatGoogleGenerativeAI(
            google_api_key=params.get("api_key"),
            model=MODEL_OPTIONS['Google'],
            temperature=kwargs.get('temperature', 0.7),
            max_tokens=kwargs.get('max_tokens', 4096),
            max_retries=2,
        )
    elif llm_provider == "Groq":
        return ChatGroq(
            api_key=params.get("api_key"),  # ChatGroq accepts api_key as an alias for groq_api_key
            model=MODEL_OPTIONS['Groq'],
            temperature=kwargs.get("temperature", 0.7),
            streaming=kwargs.get("streaming", False),
        )
    else:
        raise ValueError(f"Unsupported LLM provider: {llm_provider}")
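
# Example usage sketch (hypothetical values; assumes the Streamlit app has
# already stored provider credentials under st.session_state['params']):
#
#   st.session_state['params'] = {"api_key": "sk-...", "base_url": None}
#   llm = create_llm_model("OpenAI", temperature=0.2)
#   print(llm.invoke("Say hi").content)
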
def get_response(prompt: str, llm_provider: str):
    """Get a response from the LLM using the standard LangChain interface."""
    try:
        # Create the LLM instance dynamically
        llm = create_llm_model(llm_provider)
        # Wrap the prompt in a HumanMessage
        message = HumanMessage(content=prompt)
        # Invoke the model and return the output content
        response = llm.invoke([message])
        return response.content
    except Exception as e:
        return f"Error during LLM invocation: {str(e)}"
def get_response_stream(
    prompt: str,
    llm_provider: str,
    system: Optional[str] = None,
    temperature: float = 1.0,
    max_tokens: int = 4096,
    **kwargs,
):
    """
    Get a streaming response from the selected LLM provider.
    Provider credentials are read from st.session_state['params'];
    generation parameters are forwarded through kwargs.
    """
    try:
        # Add streaming and generation params to kwargs
        kwargs.update({
            "temperature": temperature,
            "max_tokens": max_tokens,
            "streaming": True,
        })
        # Create the LLM with streaming enabled
        llm = create_llm_model(llm_provider, **kwargs)
        # Compose messages: optional system prompt first, then the user prompt
        messages = []
        if system:
            messages.append(SystemMessage(content=system))
        messages.append(HumanMessage(content=prompt))
        # Stream the response as an iterator of message chunks
        return llm.stream(messages)
    except Exception as e:
        st.error(f"[Error during streaming: {str(e)}]")
        st.stop()
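
# Example usage sketch: Streamlit's st.write_stream can render the chunk
# iterator returned by llm.stream() directly (each chunk is an
# AIMessageChunk whose .content Streamlit knows how to display):
#
#   stream = get_response_stream(
#       prompt="Explain streaming responses.",
#       llm_provider="OpenAI",
#       system="You are a concise assistant.",
#       temperature=0.3,
#   )
#   st.write_stream(stream)
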
def test_llm_connection(llm_provider: str, test_params: Optional[dict] = None):
    """
    Test the connection to the specified LLM provider.
    Returns a tuple of (success: bool, message: str).
    """
    # Snapshot the current params up front so the finally block can always
    # restore them, even if an exception is raised mid-test
    original_params = st.session_state.get('params', {})
    try:
        # Temporarily swap in test_params if provided; otherwise the
        # existing session-state params are used as-is
        if test_params:
            st.session_state['params'] = test_params
        # Create LLM instance
        llm = create_llm_model(llm_provider)
        # Test with a simple message
        test_message = HumanMessage(content="Hello, this is a connection test. Please respond with 'OK'.")
        response = llm.invoke([test_message])
        return True, f"✅ Connection successful! Model response: {response.content[:100]}..."
    except Exception as e:
        return False, f"❌ Connection failed: {str(e)}"
    finally:
        # Restore the original params if we swapped them out
        if test_params:
            st.session_state['params'] = original_params