File size: 8,707 Bytes
00bd2b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import os
import logging
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from src.config import TOPIC_REGISTRY, MODEL_NAME, TEMPERATURE, MAX_TOKENS
from src.models import TutorResponse

# Conditional imports: each provider SDK is optional. The *_AVAILABLE flags
# recorded here drive the provider selection / fallback logic in get_llm().
try:
    from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAI
    GOOGLE_API_AVAILABLE = True
except ImportError:
    GOOGLE_API_AVAILABLE = False
    logging.warning("Google Generative AI library not available")

try:
    from langchain_huggingface import HuggingFaceEndpoint
    HUGGINGFACE_API_AVAILABLE = True
except ImportError:
    HUGGINGFACE_API_AVAILABLE = False
    logging.warning("HuggingFace library not available")

try:
    from langchain_openai import ChatOpenAI
    OPENAI_API_AVAILABLE = True
except ImportError:
    OPENAI_API_AVAILABLE = False
    logging.warning("OpenAI library not available")

# Set up logging for this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def get_llm():
    """Return an initialized LLM client.

    The provider is chosen from the ``API_TYPE`` environment variable
    ("google", "openai", or "huggingface"; default "huggingface"). If the
    preferred provider's SDK is not installed, falls back in the fixed
    order HuggingFace -> Google -> OpenAI.

    Raises:
        RuntimeError: if no provider SDK is installed at all.
    """
    preferred = os.getenv("API_TYPE", "huggingface").lower()

    # Map of provider name -> (is the SDK importable?, factory to call).
    providers = {
        "google": (GOOGLE_API_AVAILABLE, get_google_llm),
        "openai": (OPENAI_API_AVAILABLE, get_openai_llm),
        "huggingface": (HUGGINGFACE_API_AVAILABLE, get_huggingface_llm),
    }

    available, factory = providers.get(preferred, (False, None))
    if available:
        return factory()

    # Preferred provider unusable (or unknown): walk the fallback chain.
    for flag, fallback in (
        (HUGGINGFACE_API_AVAILABLE, get_huggingface_llm),
        (GOOGLE_API_AVAILABLE, get_google_llm),
        (OPENAI_API_AVAILABLE, get_openai_llm),
    ):
        if flag:
            return fallback()

    raise RuntimeError("No suitable LLM API available. Please install one of: langchain-google-genai, langchain-huggingface, langchain-openai")

def get_google_llm():
    """Build a ChatGoogleGenerativeAI client from environment + config.

    Raises:
        RuntimeError: if GOOGLE_API_KEY is not set.
    """
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise RuntimeError("GOOGLE_API_KEY is required for Google API")

    # Fall back to a current default model when none is configured.
    model = MODEL_NAME or "gemini-1.5-flash"

    logger.info(f"Initializing Google LLM with model: {model}")

    return ChatGoogleGenerativeAI(
        model=model,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        google_api_key=api_key,
        # Required for Gemini in LangChain: Gemini lacks a native system role.
        convert_system_message_to_human=True,
    )

def get_openai_llm():
    """Build a ChatOpenAI client from environment + config.

    Raises:
        RuntimeError: if OPENAI_API_KEY is not set.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is required for OpenAI API")

    # Fall back to a default chat model when none is configured.
    model = MODEL_NAME or "gpt-3.5-turbo"

    logger.info(f"Initializing OpenAI LLM with model: {model}")

    return ChatOpenAI(
        model_name=model,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        openai_api_key=api_key,
    )

def get_huggingface_llm():
    """Build a HuggingFaceEndpoint client from environment + config.

    Returns:
        A HuggingFaceEndpoint bound to the configured (or default) repo.

    Raises:
        RuntimeError: if HUGGINGFACE_API_KEY is missing or endpoint
            initialization fails.
    """
    key = os.getenv("HUGGINGFACE_API_KEY")
    if not key:
        raise RuntimeError("HUGGINGFACE_API_KEY is required for Hugging Face API. Please set your API key.")

    # Default to a good open-source model if none specified.
    model_name = MODEL_NAME if MODEL_NAME else "mistralai/Mistral-7B-Instruct-v0.2"

    logger.info(f"Initializing HuggingFace LLM with model: {model_name}")

    task = _infer_hf_task(model_name)

    try:
        return HuggingFaceEndpoint(
            repo_id=model_name,
            huggingfacehub_api_token=key,
            task=task,
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS,
        )
    except Exception as e:
        raise RuntimeError(f"Failed to initialize Hugging Face model {model_name}: {str(e)}")

def _infer_hf_task(model_name: str) -> str:
    """Infer the Inference API task from a model repo id.

    Chat-tuned families use the conversational task; seq2seq families
    (FLAN, T5) use text2text-generation; everything else defaults to
    plain text-generation.
    """
    lowered = model_name.lower()
    if any(tag in lowered for tag in ("zephyr", "dialo", "mistral")):
        return "conversational"
    # FLAN and T5 are encoder-decoder models (single merged branch; the
    # original had two identical elif arms for these).
    if "flan" in lowered or "t5" in lowered:
        return "text2text-generation"
    return "text-generation"

def validate_model_availability(model_name: str, api_key: str) -> None:
    """Validate if the specified model is available for the given API key.

    Currently a no-op placeholder: provider-agnostic validation is not
    implemented, so this logs a warning and lets initialization proceed.

    Args:
        model_name: Name of the model to check.
        api_key: API key.

    Raises:
        RuntimeError: If the model is not available.
    """
    logger.warning("Model validation is not implemented for all providers. Proceeding with initialization.")

def build_expert_prompt(topic_spec, user_question: str) -> ChatPromptTemplate:
    """Build the tutor system+human prompt template for a topic.

    Args:
        topic_spec: Topic configuration exposing ``name``,
            ``allowed_libraries``, ``banned_topics`` and ``style_guide``.
        user_question: Kept for interface compatibility; the question is
            supplied later through the template's ``{question}`` variable,
            not baked in here.

    Returns:
        A ChatPromptTemplate with two open input variables:
        ``format_instructions`` (typically bound via ``.partial()`` by the
        caller) and ``question``.
    """
    # NOTE: the original constructed an unused PydanticOutputParser here on
    # every call; format instructions are injected by the caller instead.
    system_message = f"""
You are Dr. Data, a world-class data science educator with PhDs in CS and Statistics.
You are tutoring a professional on: **{topic_spec.name}**

Context:
- Allowed libraries: {', '.join(topic_spec.allowed_libraries) or 'None'}
- Avoid: {', '.join(topic_spec.banned_topics) or 'Nothing'}
- Style: {topic_spec.style_guide}

Rules:
1. If the question is off-topic (e.g., about web dev in a Pandas session), set is_on_topic=False and give a polite redirect.
2. Always attempt diagnosis: what might the user be confused about?
3. Code must be minimal, correct, and include necessary imports.
4. Cite official documentation when possible.
5. NEVER hallucinate package functions.
6. Output ONLY in the requested JSON format.

{{format_instructions}}

"""

    return ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(system_message),
        ("human", "Question: {question}")
    ])

def generate_structured_response(topic_key: str, user_question: str) -> TutorResponse:
    """Run the tutoring chain for a topic and parse the LLM output.

    Args:
        topic_key: Key into TOPIC_REGISTRY selecting the topic spec.
        user_question: The learner's question.

    Returns:
        A validated TutorResponse.

    Raises:
        RuntimeError: if the LLM cannot be initialized or the call fails
            (with a user-friendly diagnosis of common API errors).
        ValueError: if the model output cannot be parsed into TutorResponse.
    """
    try:
        llm = get_llm()
    except Exception as e:
        raise RuntimeError(f"Failed to initialize LLM: {str(e)}")

    topic_spec = TOPIC_REGISTRY[topic_key]

    # Parser defines the JSON schema the model is instructed to emit.
    parser = PydanticOutputParser(pydantic_object=TutorResponse)

    prompt = build_expert_prompt(topic_spec, user_question)

    # Bind format instructions now; only {question} remains open at invoke time.
    chain = prompt.partial(format_instructions=parser.get_format_instructions()) | llm

    try:
        raw_output = chain.invoke({"question": user_question})
        # BUG FIX: chat models return a message object with .content, but
        # completion-style backends (e.g. HuggingFaceEndpoint, the default
        # provider) return a plain string; the old unconditional
        # raw_output.content raised AttributeError for them.
        content = raw_output.content if hasattr(raw_output, "content") else str(raw_output)
        logger.info(f"Raw LLM output: {content[:200]}...")
    except Exception as e:
        # Translate common provider failures into actionable guidance.
        error_msg = str(e).lower()
        if "401" in error_msg or "unauthorized" in error_msg:
            detailed_msg = "API key is invalid or expired. Please check your API key in the sidebar settings."
        elif "429" in error_msg or "rate limit" in error_msg:
            detailed_msg = "Rate limit exceeded. Please wait a few minutes or check your API plan limits."
        elif "connection" in error_msg or "timeout" in error_msg:
            detailed_msg = "Network connection issue. Please check your internet connection and try again."
        elif "model" in error_msg and "not found" in error_msg:
            detailed_msg = f"Model '{MODEL_NAME}' not available. Please select a valid model from the dropdown or check spelling."
        else:
            detailed_msg = f"Unexpected error: {str(e)}. Please check your model configuration."
        raise RuntimeError(f"Failed to get response from LLM: {detailed_msg}")

    # Parse and validate against the TutorResponse schema.
    try:
        response = parser.parse(content)
    except Exception as e:
        # Fallback: try to salvage a JSON object embedded in extra prose.
        import re
        import json

        json_match = re.search(r'\{.*\}', content, re.DOTALL)
        if json_match:
            try:
                json_str = json_match.group(0)
                # NOTE(review): stripping ALL newlines/tabs fixes the common
                # pretty-printed case but would mangle strings that
                # legitimately contain them — acceptable for a last-resort path.
                json_str = json_str.replace('\n', '').replace('\t', '')
                json_data = json.loads(json_str)
                response = TutorResponse(**json_data)
            except Exception as json_e:
                raise ValueError(f"Failed to parse LLM output as JSON: {json_e}\nOriginal error: {e}\nRaw: {content[:500]}...")
        else:
            # No JSON object found at all — surface the raw output for debugging.
            raise ValueError(f"Failed to parse LLM output: {e}\nRaw: {content[:500]}...")

    return response