# Gyan.AI / backend/app/ai_tutor.py
# Last commit by cryogenic22 — c7b32a3 (verified):
#   Rename src/ai_tutor.py to backend/app/ai_tutor.py
import base64
import os
from datetime import datetime
from typing import Dict, List, Optional

import streamlit as st
import torch
from dotenv import load_dotenv
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
@st.cache_resource
def _load_tutor_pipelines(
    tutor_model_name: str = "facebook/opt-1.3b",
    sentiment_model_name: str = "distilbert-base-uncased-finetuned-sst-2-english",
):
    """Load the tutoring LLM and the sentiment classifier.

    Defined at module level (not as a cached method) so ``st.cache_resource``
    keys only on the two hashable model-name strings and returns the loaded
    pipelines. A cached *method* would have to hash ``self``. Worse, on a
    cache hit its body (and so its attribute assignments) would be skipped.

    Args:
        tutor_model_name: Hugging Face model id for the text-generation pipeline.
        sentiment_model_name: Hugging Face model id for the sentiment pipeline.

    Returns:
        A ``(tutor_pipeline, sentiment_pipeline)`` tuple.
    """
    device = 0 if torch.cuda.is_available() else -1  # first GPU if present, else CPU
    tutor_pipeline = pipeline("text-generation", model=tutor_model_name, device=device)
    sentiment_pipeline = pipeline(
        "sentiment-analysis", model=sentiment_model_name, device=device
    )
    return tutor_pipeline, sentiment_pipeline


class EnhancedAITutor:
    """Streamlit chat tutor with an SVG avatar and sentiment-based engagement metrics.

    Per-user state lives in ``st.session_state.tutor_context`` so it survives
    Streamlit reruns. The model pipelines come from ``_load_tutor_pipelines``.
    """

    def __init__(self):
        load_dotenv()  # load variables from a .env file, if one exists
        self.initialize_models()
        self.initialize_session_state()
        self.load_avatar_assets()

    def initialize_models(self):
        """Attach the cached pipelines to this instance.

        Sets ``self.tutor_model`` (text generation) and ``self.sentiment_model``
        (sentiment analysis, used for engagement tracking).
        """
        self.tutor_model, self.sentiment_model = _load_tutor_pipelines()

    def initialize_session_state(self):
        """Create ``st.session_state.tutor_context`` on the first run of a session."""
        if 'tutor_context' not in st.session_state:
            st.session_state.tutor_context = {
                'chat_history': [],            # list of {"role": ..., "content": ...}
                'current_topic': None,         # None means "All Topics"
                'difficulty_level': 'intermediate',
                'learning_style': 'interactive',
                'engagement_metrics': [],      # records appended by update_engagement_metrics
            }

    @staticmethod
    def _face_svg(mouth_path: str) -> str:
        """Return the avatar face SVG with the given mouth path (SVG ``d`` attribute).

        The face, eyes and colours are identical for every state; only the mouth differs.
        """
        return (
            '<svg width="100" height="100" viewBox="0 0 100 100">'
            '<circle cx="50" cy="50" r="45" fill="#4A90E2"/>'
            '<circle cx="35" cy="40" r="5" fill="white"/>'
            '<circle cx="65" cy="40" r="5" fill="white"/>'
            f'<path d="{mouth_path}" stroke="white" fill="none" stroke-width="3"/>'
            '</svg>'
        )

    def load_avatar_assets(self):
        """Build the inline SVG for each avatar state into ``self.avatar_states``.

        These are placeholder vector faces; real avatar assets could replace them.
        """
        self.avatar_states = {
            'neutral': self._face_svg("M 30 60 Q 50 70 70 60"),   # slight smile
            'thinking': self._face_svg("M 30 65 Q 50 65 70 65"),  # flat mouth
            'happy': self._face_svg("M 30 60 Q 50 80 70 60"),     # wide smile
        }

    def display_avatar(self, state: str = 'neutral'):
        """Render the avatar for ``state``, centred; unknown states fall back to 'neutral'."""
        svg = self.avatar_states.get(state, self.avatar_states['neutral'])
        # Kept free of leading indentation: Markdown may treat indented lines
        # as a code block instead of passing the HTML through.
        st.markdown(
            '<div style="display: flex; justify-content: center; margin: 20px 0;">'
            f"{svg}</div>",
            unsafe_allow_html=True,
        )

    def generate_response(self, user_input: str) -> str:
        """Generate the tutor's reply to ``user_input`` and record engagement metrics.

        Args:
            user_input: The student's current question. It should NOT already be
                in ``chat_history``, because ``build_prompt`` appends it itself.

        Returns:
            The tutor's reply text only (the prompt is not echoed back).
        """
        # truncation=True keeps very long questions within the classifier's input limit.
        sentiment = self.sentiment_model(user_input, truncation=True)[0]

        context = st.session_state.tutor_context
        prompt = self.build_prompt(user_input, context)

        generated = self.tutor_model(
            prompt,
            max_new_tokens=200,       # budget for the reply, independent of prompt length
            num_return_sequences=1,
            return_full_text=False,   # otherwise generated_text starts with the prompt
        )[0]['generated_text']

        # A causal LM tends to keep writing the dialogue; keep only the tutor's turn.
        response = generated.split("\nStudent:", 1)[0].strip()
        if not response:
            response = "I'm not sure how to answer that yet. Could you rephrase your question?"

        self.update_engagement_metrics(user_input, response, sentiment)
        return response

    def build_prompt(self, user_input: str, context: Dict) -> str:
        """Build a context-aware text-generation prompt for the current question.

        Args:
            user_input: The student's current question.
            context: The tutor context dict created by ``initialize_session_state``.

        Returns:
            A plain-text prompt ending in ``"Tutor:"`` for the model to continue.
        """
        recent = context['chat_history'][-3:]  # only the last 3 messages, to bound prompt size
        history = "\n".join(
            f"{'Student' if msg['role'] == 'user' else 'Tutor'}: {msg['content']}"
            for msg in recent
        )

        lines = ["You are an educational AI tutor. Your goal is to help students learn effectively."]
        if context['current_topic']:
            lines.append(f"Current topic: {context['current_topic']}")
        lines.append(f"Difficulty level: {context['difficulty_level']}")
        lines.append(f"Learning style: {context['learning_style']}")
        if history:
            lines.append("Previous conversation:")
            lines.append(history)
        lines.append(f"Student: {user_input}")
        lines.append("Tutor:")
        # Joined line by line so source indentation never leaks into the prompt.
        return "\n".join(lines)

    def update_engagement_metrics(self, user_input: str, response: str, sentiment: Dict):
        """Append one engagement record for this interaction.

        Args:
            user_input: The student's message; its character length is recorded.
            response: The tutor's reply (currently not recorded).
            sentiment: One result from the sentiment pipeline, with 'label' and 'score'.
        """
        st.session_state.tutor_context['engagement_metrics'].append({
            'timestamp': datetime.now().isoformat(),
            'sentiment': sentiment['label'],
            'sentiment_score': sentiment['score'],  # confidence in 'sentiment', not positivity
            'interaction_length': len(user_input),
        })

    def display_chat_interface(self):
        """Render the topic picker, chat history and chat input, and handle a new question."""
        st.header("AI Tutor")
        self.display_avatar(state='neutral')

        topics = [None, 'Physics', 'Mathematics', 'Computer Science', 'Artificial Intelligence']
        selected_topic = st.selectbox(
            "Select Topic",
            topics,
            format_func=lambda x: 'All Topics' if x is None else x,
            key="topic_selector",
        )
        context = st.session_state.tutor_context
        context['current_topic'] = selected_topic

        with st.container():
            for message in context['chat_history']:
                with st.chat_message(message["role"]):
                    st.write(message["content"])
                if message["role"] == "assistant":
                    self.display_avatar(state='happy')

        if prompt := st.chat_input("Ask your question"):
            # Shown while the (synchronous) generation below runs.
            self.display_avatar(state='thinking')
            # Generate before recording the question: build_prompt appends the
            # current question itself, so recording it first duplicated it.
            response = self.generate_response(prompt)
            context['chat_history'].append({"role": "user", "content": prompt})
            context['chat_history'].append({"role": "assistant", "content": response})
            # Rerun so the history loop above renders the new turns.
            st.rerun()

    @staticmethod
    def _positivity(metric: Dict) -> float:
        """Convert a stored sentiment record to the probability it is positive, in [0, 1].

        The classifier's score is its confidence in the predicted label, so a
        NEGATIVE prediction with score s corresponds to P(positive) = 1 - s.
        """
        score = metric['sentiment_score']
        return score if metric['sentiment'] == 'POSITIVE' else 1.0 - score

    def display_learning_metrics(self):
        """Render engagement score, question count and current topic in the sidebar."""
        context = st.session_state.tutor_context
        with st.sidebar:
            st.subheader("Learning Metrics")

            positivity = [self._positivity(m) for m in context['engagement_metrics']]
            if positivity:
                average = sum(positivity) / len(positivity)
                delta = None
                if len(positivity) > 1:
                    # Real change caused by the latest interaction.
                    previous = sum(positivity[:-1]) / (len(positivity) - 1)
                    delta = f"{average - previous:+.2f}"
                st.metric("Engagement Score", f"{average:.2f}", delta=delta)

            if context['chat_history']:
                st.metric(
                    "Questions Asked",
                    sum(1 for m in context['chat_history'] if m['role'] == 'user'),
                )

            if context['current_topic']:
                st.info(f"Current focus: {context['current_topic']}")

    def save_chat_history(self):
        """Persist the chat history. Not implemented yet; currently a no-op."""
        # TODO: save to a database.
        pass

    def load_chat_history(self):
        """Restore a previous chat history. Not implemented yet; currently a no-op."""
        # TODO: load from a database.
        pass