DasariHarshitha's picture
Update app.py
3a525e0 verified
raw
history blame
3.69 kB
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage
from langchain.memory import ConversationBufferMemory
# ---------------------------
# Streamlit UI Config
# ---------------------------
# NOTE: set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="AI Data Science Tutor", layout="wide")
st.title("πŸ“Š AI Conversational Data Science Tutor")
# ---------------------------
# Sidebar - Settings
# ---------------------------
# The selected mode drives get_tutor_response(): canned answers vs a real LLM.
st.sidebar.header("βš™οΈ Settings")
mode = st.sidebar.radio(
"Choose Tutor Mode:",
("Dummy Tutor (No API Key)", "OpenAI Tutor (API Key Required)")
)
# Only ask for a key when the OpenAI-backed tutor is selected; the key is
# never persisted beyond this Streamlit session.
if mode == "OpenAI Tutor (API Key Required)":
    openai_api_key = st.sidebar.text_input("Enter your OpenAI API Key:", type="password")
else:
    openai_api_key = None
# ---------------------------
# Initialize Memory
# ---------------------------
# Two separate stores, initialized once per session (Streamlit reruns the
# whole script on every interaction):
#   - memory: LangChain conversation buffer sent to the LLM for context
#   - chat_history: (sender, text) tuples used only for on-screen rendering
if "memory" not in st.session_state:
    st.session_state.memory = ConversationBufferMemory(return_messages=True)
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# ---------------------------
# Dummy Tutor Response Logic
# ---------------------------
def dummy_tutor_response(user_query):
    """Return a canned Data Science answer matched by keyword.

    Scans the lowercased query against an ordered list of keyword groups;
    the first group containing any matching keyword wins (so e.g. a query
    mentioning both "regression" and "classification" gets the regression
    answer). Falls back to a "please rephrase" prompt echoing the original
    question verbatim.
    """
    lowered = user_query.lower()
    # Ordered (keywords, answer) pairs — order encodes match priority.
    canned_answers = [
        (("regression",),
         "πŸ“Š Regression is a supervised ML technique used to predict continuous values."),
        (("classification",),
         "πŸ” Classification predicts categorical labels, e.g., spam vs not spam."),
        (("neural", "deep learning"),
         "🧠 Neural Networks consist of layers of neurons that learn patterns from data."),
        (("pca",),
         "πŸ“‰ PCA reduces dimensions while preserving variance."),
        (("accuracy", "precision", "recall"),
         "βœ… Accuracy = correct predictions / total. For imbalanced data, use precision, recall, or F1-score."),
        (("clustering",),
         "πŸ“Œ Clustering groups similar points without labels (unsupervised learning)."),
        (("overfitting",),
         "⚠️ Overfitting means the model memorizes data instead of generalizing."),
    ]
    for keywords, answer in canned_answers:
        if any(keyword in lowered for keyword in keywords):
            return answer
    return f"πŸ€” I didn’t fully get that. Can you rephrase your Data Science question? (You asked: {user_query})"
# ---------------------------
# Get AI Tutor Response
# ---------------------------
def get_tutor_response(user_query):
    """Route a user question to the selected tutor backend and return the reply.

    In dummy mode the reply is a canned keyword-matched answer. In OpenAI
    mode the entire conversation history kept in st.session_state.memory is
    sent to the chat model so answers stay context-aware, and both the user
    turn and the AI turn are appended to that memory. If OpenAI mode is
    selected but no API key was entered, a warning string is returned.
    """
    if mode == "Dummy Tutor (No API Key)":
        return dummy_tutor_response(user_query)
    if mode == "OpenAI Tutor (API Key Required)" and openai_api_key:
        llm = ChatOpenAI(
            model="gpt-4o-mini",  # You can also try "gpt-4o" or "gpt-4"
            openai_api_key=openai_api_key,
            temperature=0.5
        )
        # Record the user turn first so the model sees it in the history.
        st.session_state.memory.chat_memory.add_user_message(user_query)
        # FIX: calling the model instance directly (llm(messages)) is the
        # deprecated __call__ path in LangChain; .invoke() is the supported
        # Runnable entry point and returns the same AIMessage.
        response = llm.invoke(st.session_state.memory.chat_memory.messages)
        # Persist the AI turn so the next request keeps full context.
        st.session_state.memory.chat_memory.add_ai_message(response.content)
        return response.content
    return "⚠️ Please provide your OpenAI API key in the sidebar."
# ---------------------------
# Chat UI
# ---------------------------
# Chat input box pinned to the bottom of the page; returns None until the
# user submits a message (the script reruns on each submit).
user_query = st.chat_input("Ask me a Data Science question...")
if user_query:
    # Answer first, then append both turns so they render in this same run.
    response = get_tutor_response(user_query)
    st.session_state.chat_history.append(("You", user_query))
    st.session_state.chat_history.append(("Tutor", response))
# Display Chat History
# Re-render the full transcript on every rerun; the sender tag picks the
# avatar emoji in the markdown prefix.
for sender, msg in st.session_state.chat_history:
    if sender == "You":
        st.markdown(f"**πŸ‘©β€πŸ’» {sender}:** {msg}")
    else:
        st.markdown(f"**πŸ€– {sender}:** {msg}")