# src/chat_engine.py
# Chat engine: selects an available LLM backend and produces structured tutor responses.
import os
import logging
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from src.config import TOPIC_REGISTRY, MODEL_NAME, TEMPERATURE, MAX_TOKENS
from src.models import TutorResponse

# Conditional imports based on available API
# Each optional provider SDK is probed at import time; the *_AVAILABLE flags
# are read later (by get_llm) to pick a usable backend at runtime.
try:
    from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAI
    GOOGLE_API_AVAILABLE = True
except ImportError:
    GOOGLE_API_AVAILABLE = False
    logging.warning("Google Generative AI library not available")

try:
    from langchain_huggingface import HuggingFaceEndpoint
    HUGGINGFACE_API_AVAILABLE = True
except ImportError:
    HUGGINGFACE_API_AVAILABLE = False
    logging.warning("HuggingFace library not available")

try:
    from langchain_openai import ChatOpenAI
    OPENAI_API_AVAILABLE = True
except ImportError:
    OPENAI_API_AVAILABLE = False
    logging.warning("OpenAI library not available")

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_llm():
    """Return an LLM client for the backend named by the API_TYPE env var.

    Recognized values are "google", "openai" and "huggingface" (default).
    When the requested backend's library is not installed, falls back
    through HuggingFace, then Google, then OpenAI.

    Raises:
        RuntimeError: when none of the provider libraries is installed.
    """
    preferred = os.getenv("API_TYPE", "huggingface").lower()
    # Availability flag and factory for each supported backend.
    backends = {
        "google": (GOOGLE_API_AVAILABLE, get_google_llm),
        "openai": (OPENAI_API_AVAILABLE, get_openai_llm),
        "huggingface": (HUGGINGFACE_API_AVAILABLE, get_huggingface_llm),
    }
    is_available, factory = backends.get(preferred, (False, None))
    if is_available:
        return factory()
    # Preferred backend unavailable (or unrecognized): fixed fallback order.
    for fallback in ("huggingface", "google", "openai"):
        is_available, factory = backends[fallback]
        if is_available:
            return factory()
    raise RuntimeError("No suitable LLM API available. Please install one of: langchain-google-genai, langchain-huggingface, langchain-openai")
def get_google_llm():
    """Create a ChatGoogleGenerativeAI client from environment settings.

    Raises:
        RuntimeError: when GOOGLE_API_KEY is missing from the environment.
    """
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise RuntimeError("GOOGLE_API_KEY is required for Google API")
    # Fall back to a more current default model when none is configured.
    model_name = MODEL_NAME or "gemini-1.5-flash"
    logger.info(f"Initializing Google LLM with model: {model_name}")
    return ChatGoogleGenerativeAI(
        model=model_name,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        google_api_key=api_key,
        convert_system_message_to_human=True,  # Required for Gemini in LangChain
    )
def get_openai_llm():
    """Create a ChatOpenAI client from environment settings.

    Returns:
        ChatOpenAI: chat model configured with MODEL_NAME (default
        "gpt-3.5-turbo"), TEMPERATURE and MAX_TOKENS.

    Raises:
        RuntimeError: when OPENAI_API_KEY is missing from the environment.
    """
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise RuntimeError("OPENAI_API_KEY is required for OpenAI API")
    # Ensure model name is set
    model_name = MODEL_NAME if MODEL_NAME else "gpt-3.5-turbo"
    logger.info(f"Initializing OpenAI LLM with model: {model_name}")
    # Use the current parameter names (`model`, `api_key`) for consistency
    # with get_google_llm(); the old aliases (`model_name`, `openai_api_key`)
    # are deprecated in langchain-openai but still accepted.
    return ChatOpenAI(
        model=model_name,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        api_key=key,
    )
# --- Hugging Face backend -------------------------------------------------
def _infer_hf_task(model_name: str) -> str:
    """Map a Hugging Face repo id to the inference task it should be run as.

    Chat-tuned models (zephyr/DialoGPT/mistral families) use the
    "conversational" task; seq2seq families (FLAN, T5) use
    "text2text-generation"; everything else defaults to "text-generation".
    """
    lowered = model_name.lower()  # hoist: compared against several substrings
    if any(tag in lowered for tag in ("zephyr", "dialo", "mistral")):
        return "conversational"
    # "flan" and "t5" previously had duplicate branches; merged here.
    if "flan" in lowered or "t5" in lowered:
        return "text2text-generation"
    return "text-generation"


def get_huggingface_llm():
    """Create a HuggingFaceEndpoint client from environment settings.

    Returns:
        HuggingFaceEndpoint: endpoint for MODEL_NAME (default
        "mistralai/Mistral-7B-Instruct-v0.2") with TEMPERATURE / MAX_TOKENS.

    Raises:
        RuntimeError: when HUGGINGFACE_API_KEY is missing, or when the
        endpoint fails to initialize.
    """
    key = os.getenv("HUGGINGFACE_API_KEY")
    # Check if API key is provided
    if not key:
        raise RuntimeError("HUGGINGFACE_API_KEY is required for Hugging Face API. Please set your API key.")
    # Default to a good open-source model if none specified
    model_name = MODEL_NAME if MODEL_NAME else "mistralai/Mistral-7B-Instruct-v0.2"
    logger.info(f"Initializing HuggingFace LLM with model: {model_name}")
    task = _infer_hf_task(model_name)
    # Try to initialize the HuggingFace endpoint
    try:
        return HuggingFaceEndpoint(
            repo_id=model_name,
            huggingfacehub_api_token=key,
            task=task,
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS,
        )
    except Exception as e:
        # Chain the original exception for easier debugging.
        raise RuntimeError(f"Failed to initialize Hugging Face model {model_name}: {str(e)}") from e
def validate_model_availability(model_name: str, api_key: str) -> None:
    """Check whether *model_name* can be used with *api_key*.

    Currently a stub: there is no uniform availability check across
    providers, so this only logs a warning and returns. The parameters are
    accepted so call sites remain stable once real validation is added.

    Args:
        model_name: Name of the model to check (currently unused).
        api_key: API key for the target provider (currently unused).
    """
    # Simplified validation approach. (Redundant trailing `pass` removed;
    # the old docstring's "Raises: RuntimeError" claim was never true.)
    logger.warning("Model validation is not implemented for all providers. Proceeding with initialization.")
def build_expert_prompt(topic_spec, user_question: str) -> ChatPromptTemplate:
    """Build the "Dr. Data" tutoring prompt for *topic_spec*.

    Note: ``user_question`` is kept for interface compatibility but is not
    interpolated here — the question is bound via the ``{question}`` template
    variable at invocation time. ``{format_instructions}`` is likewise left
    as a template variable for the caller to fill via ``.partial`` (see
    generate_structured_response). The previously constructed local
    ``PydanticOutputParser`` was never used and has been removed.

    Args:
        topic_spec: Topic descriptor with ``name``, ``allowed_libraries``,
            ``banned_topics`` and ``style_guide`` attributes.
        user_question: The learner's question (unused here; see note).

    Returns:
        ChatPromptTemplate with a system message and a human "{question}" slot.
    """
    system_message = f"""
You are Dr. Data, a world-class data science educator with PhDs in CS and Statistics.
You are tutoring a professional on: **{topic_spec.name}**
Context:
- Allowed libraries: {', '.join(topic_spec.allowed_libraries) or 'None'}
- Avoid: {', '.join(topic_spec.banned_topics) or 'Nothing'}
- Style: {topic_spec.style_guide}
Rules:
1. If the question is off-topic (e.g., about web dev in a Pandas session), set is_on_topic=False and give a polite redirect.
2. Always attempt diagnosis: what might the user be confused about?
3. Code must be minimal, correct, and include necessary imports.
4. Cite official documentation when possible.
5. NEVER hallucinate package functions.
6. Output ONLY in the requested JSON format.
{{format_instructions}}
"""
    return ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(system_message),
        ("human", "Question: {question}")
    ])
def _describe_llm_error(e: Exception) -> str:
    """Map a raw LLM invocation error to a user-actionable message."""
    error_msg = str(e).lower()
    if "401" in error_msg or "unauthorized" in error_msg:
        return "API key is invalid or expired. Please check your API key in the sidebar settings."
    if "429" in error_msg or "rate limit" in error_msg:
        return "Rate limit exceeded. Please wait a few minutes or check your API plan limits."
    if "connection" in error_msg or "timeout" in error_msg:
        return "Network connection issue. Please check your internet connection and try again."
    if "model" in error_msg and "not found" in error_msg:
        return f"Model '{MODEL_NAME}' not available. Please select a valid model from the dropdown or check spelling."
    return f"Unexpected error: {str(e)}. Please check your model configuration."


def _parse_tutor_response(parser, content: str) -> TutorResponse:
    """Parse LLM text into a TutorResponse, with a JSON-extraction fallback.

    Raises:
        ValueError: when neither the structured parser nor the raw-JSON
            fallback can produce a TutorResponse.
    """
    try:
        return parser.parse(content)
    except Exception as e:
        # Try to extract JSON from the response if structured parsing fails.
        import re
        import json
        # Look for the widest {...} span in the response (greedy match).
        json_match = re.search(r'\{.*\}', content, re.DOTALL)
        if not json_match:
            raise ValueError(f"Failed to parse LLM output: {e}\nRaw: {content[:500]}...")
        try:
            json_str = json_match.group(0)
            # Strip literal newlines/tabs, which are invalid inside JSON strings.
            json_str = json_str.replace('\n', '').replace('\t', '')
            # Parse and reconstruct the validated response model.
            json_data = json.loads(json_str)
            return TutorResponse(**json_data)
        except Exception as json_e:
            raise ValueError(
                f"Failed to parse LLM output as JSON: {json_e}\nOriginal error: {e}\nRaw: {content[:500]}..."
            ) from json_e


def generate_structured_response(topic_key: str, user_question: str) -> TutorResponse:
    """Generate a validated TutorResponse for *user_question* on *topic_key*.

    Args:
        topic_key: Key into TOPIC_REGISTRY selecting the topic spec.
        user_question: The learner's question, passed to the prompt chain.

    Raises:
        RuntimeError: if the LLM cannot be initialized or invoked.
        ValueError: if the LLM output cannot be parsed into a TutorResponse.
        KeyError: if *topic_key* is not present in TOPIC_REGISTRY.
    """
    try:
        llm = get_llm()
    except Exception as e:
        raise RuntimeError(f"Failed to initialize LLM: {str(e)}") from e
    topic_spec = TOPIC_REGISTRY[topic_key]
    # Create parser (its format instructions are bound into the prompt below).
    parser = PydanticOutputParser(pydantic_object=TutorResponse)
    # Build prompt with proper variable names
    prompt = build_expert_prompt(topic_spec, user_question)
    # Bind format_instructions now; {question} is filled at invoke time.
    chain = prompt.partial(format_instructions=parser.get_format_instructions()) | llm
    # Invoke with the question
    try:
        raw_output = chain.invoke({"question": user_question})
        logger.info(f"Raw LLM output: {raw_output.content[:200]}...")
    except Exception as e:
        raise RuntimeError(f"Failed to get response from LLM: {_describe_llm_error(e)}") from e
    # Parse and validate (with JSON-recovery fallback).
    return _parse_tutor_response(parser, raw_output.content)