File size: 8,707 Bytes
00bd2b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
import os
import logging
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from src.config import TOPIC_REGISTRY, MODEL_NAME, TEMPERATURE, MAX_TOKENS
from src.models import TutorResponse
# Conditional imports based on available API
# Each provider SDK is optional; a module-level flag records whether its
# LangChain integration imported successfully so get_llm() can pick a
# working backend at runtime.
try:
    from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAI
    GOOGLE_API_AVAILABLE = True
except ImportError:
    GOOGLE_API_AVAILABLE = False
    logging.warning("Google Generative AI library not available")
try:
    from langchain_huggingface import HuggingFaceEndpoint
    HUGGINGFACE_API_AVAILABLE = True
except ImportError:
    HUGGINGFACE_API_AVAILABLE = False
    logging.warning("HuggingFace library not available")
try:
    from langchain_openai import ChatOpenAI
    OPENAI_API_AVAILABLE = True
except ImportError:
    OPENAI_API_AVAILABLE = False
    logging.warning("OpenAI library not available")
# Set up logging
# NOTE(review): the warnings above fire before basicConfig runs; they still
# appear because WARNING is logging's default level — confirm intended.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_llm():
    """Select and construct an LLM client for the configured provider.

    The provider is chosen via the API_TYPE environment variable
    ("google", "openai", or "huggingface"; default "huggingface").
    When the requested provider's library is not installed, falls back
    to the first available one in order: HuggingFace, Google, OpenAI.

    Returns:
        A LangChain LLM/chat-model instance from one of the provider
        factory functions.

    Raises:
        RuntimeError: if no supported LLM library is installed.
    """
    preferred = os.getenv("API_TYPE", "huggingface").lower()
    # Honor the explicit preference first, when its library is importable.
    if preferred == "google" and GOOGLE_API_AVAILABLE:
        return get_google_llm()
    if preferred == "openai" and OPENAI_API_AVAILABLE:
        return get_openai_llm()
    if preferred == "huggingface" and HUGGINGFACE_API_AVAILABLE:
        return get_huggingface_llm()
    # Preferred provider unavailable: fall back in priority order.
    if HUGGINGFACE_API_AVAILABLE:
        return get_huggingface_llm()
    if GOOGLE_API_AVAILABLE:
        return get_google_llm()
    if OPENAI_API_AVAILABLE:
        return get_openai_llm()
    raise RuntimeError("No suitable LLM API available. Please install one of: langchain-google-genai, langchain-huggingface, langchain-openai")
def get_google_llm():
    """Build a ChatGoogleGenerativeAI client from env/config settings.

    Reads GOOGLE_API_KEY from the environment and the model name from
    config, defaulting to "gemini-1.5-flash" when unset.

    Raises:
        RuntimeError: if GOOGLE_API_KEY is not set.
    """
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise RuntimeError("GOOGLE_API_KEY is required for Google API")
    # Fall back to a current default model when config leaves it unset.
    chosen_model = MODEL_NAME or "gemini-1.5-flash"
    logger.info(f"Initializing Google LLM with model: {chosen_model}")
    return ChatGoogleGenerativeAI(
        google_api_key=api_key,
        model=chosen_model,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        # Gemini rejects system messages; LangChain folds them into the
        # human turn when this flag is set.
        convert_system_message_to_human=True,
    )
def get_openai_llm():
    """Build a ChatOpenAI client from env/config settings.

    Reads OPENAI_API_KEY from the environment and the model name from
    config, defaulting to "gpt-3.5-turbo" when unset.

    Raises:
        RuntimeError: if OPENAI_API_KEY is not set.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is required for OpenAI API")
    # Fall back to a sensible default model when config leaves it unset.
    chosen_model = MODEL_NAME or "gpt-3.5-turbo"
    logger.info(f"Initializing OpenAI LLM with model: {chosen_model}")
    return ChatOpenAI(
        openai_api_key=api_key,
        model_name=chosen_model,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
    )
# --- Hugging Face provider ---
def get_huggingface_llm():
    """Build a HuggingFaceEndpoint client from env/config settings.

    Reads HUGGINGFACE_API_KEY from the environment and the model repo
    from config, defaulting to "mistralai/Mistral-7B-Instruct-v0.2".
    The inference task is inferred from the model family name.

    Raises:
        RuntimeError: if HUGGINGFACE_API_KEY is missing or the endpoint
            fails to initialize.
    """
    token = os.getenv("HUGGINGFACE_API_KEY")
    if not token:
        raise RuntimeError("HUGGINGFACE_API_KEY is required for Hugging Face API. Please set your API key.")
    # Default to a good open-source model if none specified.
    repo = MODEL_NAME or "mistralai/Mistral-7B-Instruct-v0.2"
    logger.info(f"Initializing HuggingFace LLM with model: {repo}")
    # Infer the inference task from the model family in the repo name.
    lowered = repo.lower()
    if any(tag in lowered for tag in ("zephyr", "dialo", "mistral")):
        task = "conversational"
    elif "flan" in lowered or "t5" in lowered:
        task = "text2text-generation"
    else:
        task = "text-generation"
    try:
        return HuggingFaceEndpoint(
            repo_id=repo,
            huggingfacehub_api_token=token,
            task=task,
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS,
        )
    except Exception as e:
        raise RuntimeError(f"Failed to initialize Hugging Face model {repo}: {str(e)}")
def validate_model_availability(model_name: str, api_key: str) -> None:
    """Placeholder hook for provider-side model validation.

    Currently a no-op: it logs a warning and returns None. The previous
    docstring claimed a RuntimeError was raised for unavailable models,
    but no validation is actually performed for any provider.

    Args:
        model_name: Name of the model to check.
        api_key: API key for the target provider.
    """
    # TODO: implement a per-provider model lookup (e.g. list-models API
    # call) and raise RuntimeError on a miss; until then, just warn.
    logger.warning("Model validation is not implemented for all providers. Proceeding with initialization.")
def build_expert_prompt(topic_spec, user_question: str) -> ChatPromptTemplate:
    """Build the tutoring chat prompt for a topic.

    Args:
        topic_spec: Topic configuration exposing ``name``,
            ``allowed_libraries``, ``banned_topics`` and ``style_guide``.
        user_question: The learner's question. Unused here (kept for
            interface compatibility); the actual question is bound at
            invoke time through the template's ``{question}`` variable.

    Returns:
        A ChatPromptTemplate with input variables ``question`` and
        ``format_instructions`` (the latter typically filled via
        ``.partial()`` by the caller).
    """
    # Fix: removed an unused PydanticOutputParser that was constructed
    # here on every call; format_instructions is supplied by the caller.
    system_message = f"""
You are Dr. Data, a world-class data science educator with PhDs in CS and Statistics.
You are tutoring a professional on: **{topic_spec.name}**
Context:
- Allowed libraries: {', '.join(topic_spec.allowed_libraries) or 'None'}
- Avoid: {', '.join(topic_spec.banned_topics) or 'Nothing'}
- Style: {topic_spec.style_guide}
Rules:
1. If the question is off-topic (e.g., about web dev in a Pandas session), set is_on_topic=False and give a polite redirect.
2. Always attempt diagnosis: what might the user be confused about?
3. Code must be minimal, correct, and include necessary imports.
4. Cite official documentation when possible.
5. NEVER hallucinate package functions.
6. Output ONLY in the requested JSON format.
{{format_instructions}}
"""
    return ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(system_message),
        ("human", "Question: {question}")
    ])
def generate_structured_response(topic_key: str, user_question: str) -> TutorResponse:
    """Run the tutoring chain for a topic and return a validated response.

    Args:
        topic_key: Key into TOPIC_REGISTRY selecting the topic spec.
        user_question: The learner's question, passed to the prompt.

    Returns:
        TutorResponse parsed from the LLM output.

    Raises:
        RuntimeError: if the LLM cannot be initialized or invoked.
        ValueError: if the LLM output cannot be parsed as a TutorResponse.
        KeyError: if ``topic_key`` is not in TOPIC_REGISTRY.
    """
    try:
        llm = get_llm()
    except Exception as e:
        raise RuntimeError(f"Failed to initialize LLM: {str(e)}")
    topic_spec = TOPIC_REGISTRY[topic_key]
    # Parser both supplies format instructions and validates the output.
    parser = PydanticOutputParser(pydantic_object=TutorResponse)
    prompt = build_expert_prompt(topic_spec, user_question)
    chain = prompt.partial(format_instructions=parser.get_format_instructions()) | llm
    try:
        raw_output = chain.invoke({"question": user_question})
        # Fix: chat models return a message object with .content, but plain
        # text endpoints (e.g. HuggingFaceEndpoint) return a bare string;
        # the old unconditional .content access raised AttributeError there.
        content = raw_output.content if hasattr(raw_output, "content") else str(raw_output)
        logger.info(f"Raw LLM output: {content[:200]}...")
    except Exception as e:
        raise RuntimeError(f"Failed to get response from LLM: {_describe_llm_error(e)}")
    return _parse_tutor_response(parser, content)


def _describe_llm_error(e: Exception) -> str:
    """Map a raw LLM invocation error to a user-actionable message."""
    error_msg = str(e).lower()
    if "401" in error_msg or "unauthorized" in error_msg:
        return "API key is invalid or expired. Please check your API key in the sidebar settings."
    if "429" in error_msg or "rate limit" in error_msg:
        return "Rate limit exceeded. Please wait a few minutes or check your API plan limits."
    if "connection" in error_msg or "timeout" in error_msg:
        return "Network connection issue. Please check your internet connection and try again."
    if "model" in error_msg and "not found" in error_msg:
        return f"Model '{MODEL_NAME}' not available. Please select a valid model from the dropdown or check spelling."
    return f"Unexpected error: {str(e)}. Please check your model configuration."


def _parse_tutor_response(parser, content: str) -> TutorResponse:
    """Parse LLM text into a TutorResponse, with a JSON-extraction fallback.

    Raises:
        ValueError: if neither the structured parser nor the raw-JSON
            fallback can make sense of ``content``.
    """
    try:
        return parser.parse(content)
    except Exception as e:
        # Fallback: pull the first {...} span out of the raw text.
        import re
        import json
        json_match = re.search(r'\{.*\}', content, re.DOTALL)
        if not json_match:
            raise ValueError(f"Failed to parse LLM output: {e}\nRaw: {content[:500]}...")
        try:
            json_str = json_match.group(0)
            # NOTE(review): stripping newlines/tabs also removes them inside
            # string values — acceptable for this best-effort fallback, but
            # it alters multi-line answers; confirm this is intended.
            json_str = json_str.replace('\n', '').replace('\t', '')
            json_data = json.loads(json_str)
            return TutorResponse(**json_data)
        except Exception as json_e:
            raise ValueError(f"Failed to parse LLM output as JSON: {json_e}\nOriginal error: {e}\nRaw: {content[:500]}...")