# VayuChatv2 / app.py
# Last updated by AbhayVG — commit 74f8419 ("updated", verified)
import streamlit as st
import os
import json
import pandas as pd
import random
from os.path import join
from datetime import datetime
from src import (
preprocess_and_load_df,
load_agent,
ask_agent,
decorate_with_code,
show_response,
get_from_user,
load_smart_df,
ask_question,
)
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from streamlit_feedback import streamlit_feedback
from huggingface_hub import HfApi
from datasets import load_dataset, get_dataset_config_info, Dataset
from PIL import Image
import time
import uuid
# Page configuration: browser-tab title/icon, wide layout and an expanded
# sidebar. Kept as the first Streamlit call of the script.
_PAGE_CONFIG = {
    "page_title": "VayuChat - AI Air Quality Assistant",
    "page_icon": "🌬️",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS for beautiful styling.
# Injected as a raw <style> tag via st.markdown(unsafe_allow_html=True).
# Two fixes are applied here:
#   * .quick-prompt-container was never closed. Every later rule was then
#     either dropped by the CSS parser or treated as nested inside it, so
#     .user-message, .assistant-message, etc. never applied.
#   * Two stray Python lines had been pasted inside .user-info. They
#     corrupted that declaration block and dropped its font-size.
st.markdown("""
<style>
/* Clean app background */
.stApp {
background-color: #ffffff;
color: #212529;
font-family: 'Segoe UI', sans-serif;
}
/* Sidebar */
[data-testid="stSidebar"] {
background-color: #f8f9fa;
border-right: 1px solid #dee2e6;
padding: 1rem;
}
/* Main title */
.main-title {
text-align: center;
color: #343a40;
font-size: 2.5rem;
font-weight: 700;
margin-bottom: 0.5rem;
}
/* Subtitle */
.subtitle {
text-align: center;
color: #6c757d;
font-size: 1.1rem;
margin-bottom: 1.5rem;
}
/* Instructions */
.instructions {
background-color: #f1f3f5;
border-left: 4px solid #0d6efd;
padding: 1rem;
margin-bottom: 1.5rem;
border-radius: 6px;
color: #495057;
text-align: left;
}
/* Quick prompt buttons */
.quick-prompt-container {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 1.5rem;
padding: 1rem;
background-color: #f8f9fa;
border-radius: 10px;
border: 1px solid #dee2e6;
}
.quick-prompt-btn {
cursor: pointer;
transition: all 0.2s ease;
white-space: nowrap;
}
.user-message {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 15px 20px;
border-radius: 20px 20px 5px 20px;
margin: 10px 0;
margin-left: auto;
margin-right: 0;
max-width: 80%;
position: relative;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.user-info {
font-size: 0.8rem;
opacity: 0.8;
margin-bottom: 5px;
text-align: right;
}
/* Assistant message styling */
.assistant-message {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
color: white;
padding: 15px 20px;
border-radius: 20px 20px 20px 5px;
margin: 10px 0;
margin-left: 0;
margin-right: auto;
max-width: 80%;
position: relative;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.assistant-info {
font-size: 0.8rem;
opacity: 0.8;
margin-bottom: 5px;
}
/* Processing indicator */
.processing-indicator {
background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%);
color: #333;
padding: 15px 20px;
border-radius: 20px 20px 20px 5px;
margin: 10px 0;
margin-left: 0;
margin-right: auto;
max-width: 80%;
position: relative;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
animation: pulse 2s infinite;
}
@keyframes pulse {
0% { opacity: 1; }
50% { opacity: 0.7; }
100% { opacity: 1; }
}
/* Feedback box */
.feedback-section {
background-color: #f8f9fa;
border: 1px solid #dee2e6;
padding: 1rem;
border-radius: 8px;
margin: 1rem 0;
}
/* Success and error messages */
.success-message {
background-color: #d1e7dd;
color: #0f5132;
padding: 1rem;
border-radius: 6px;
border: 1px solid #badbcc;
}
.error-message {
background-color: #f8d7da;
color: #842029;
padding: 1rem;
border-radius: 6px;
border: 1px solid #f5c2c7;
}
/* Chat input */
.stChatInput {
border-radius: 6px;
border: 1px solid #ced4da;
background: #ffffff;
}
/* Button */
.stButton > button {
background-color: #0d6efd;
color: white;
border-radius: 6px;
padding: 0.5rem 1.25rem;
border: none;
font-weight: 600;
transition: background-color 0.2s ease;
}
.stButton > button:hover {
background-color: #0b5ed7;
}
/* Code details styling */
.code-details {
background-color: #f8f9fa;
border: 1px solid #dee2e6;
border-radius: 8px;
padding: 10px;
margin-top: 10px;
}
/* Hide default menu and footer */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* Auto scroll */
.main-container {
height: 70vh;
overflow-y: auto;
}
</style>
""", unsafe_allow_html=True)
# Auto-scroll helper: defines a browser-side scrollToBottom() that waits
# 100 ms, then scrolls the .main-container element (when present) and the
# whole window to the bottom.
# NOTE(review): st.markdown is not expected to execute <script> tags, even with
# unsafe_allow_html=True. This function may therefore never exist in the page.
# Confirm in the browser; streamlit.components.v1.html is the usual alternative.
_AUTO_SCROLL_SCRIPT = """
<script>
function scrollToBottom() {
setTimeout(function() {
const mainContainer = document.querySelector('.main-container');
if (mainContainer) {
mainContainer.scrollTop = mainContainer.scrollHeight;
}
window.scrollTo(0, document.body.scrollHeight);
}, 100);
}
</script>
"""
st.markdown(_AUTO_SCROLL_SCRIPT, unsafe_allow_html=True)
# Re-read .env on every script run. override=True lets .env values replace
# variables already present in the process environment.
load_dotenv(override=True)

# Provider credentials; each is None when the variable is unset.
Groq_Token = os.getenv("GROQ_API_KEY")
hf_token = os.getenv("HF_TOKEN")
gemini_token = os.getenv("GEMINI_TOKEN")

# Sidebar display name -> provider model id. Insertion order is the order the
# sidebar offers them. Names containing "gemini" need GEMINI_TOKEN; all other
# names need GROQ_API_KEY.
models = {
    "gpt-oss-20b": "openai/gpt-oss-20b",
    "gpt-oss-120b": "openai/gpt-oss-120b",
    "llama3.1": "llama-3.1-8b-instant",
    "llama3.3": "llama-3.3-70b-versatile",
    "deepseek-R1": "deepseek-r1-distill-llama-70b",
    "llama4 maverik": "meta-llama/llama-4-maverick-17b-128e-instruct",
    "llama4 scout": "meta-llama/llama-4-scout-17b-16e-instruct",
    "gemini-pro": "gemini-1.5-pro",
}

# Directory containing this script; data files are resolved relative to it.
self_path = os.path.split(os.path.abspath(__file__))[0]

# One random id per browser session, attached to uploaded feedback.
if "session_id" not in st.session_state:
    new_session_id = uuid.uuid4()
    st.session_state.session_id = str(new_session_id)
def upload_feedback(feedback, error, output, last_prompt, code, status):
    """Write a Markdown feedback report and upload it to the HF feedback dataset.

    The report goes to ``data/<name>.md`` in the
    ``SustainabilityLabIITGN/VayuChat_Feedback`` dataset repo.

    Args:
        feedback: Dict with optional "score" and "text" keys from the user.
        error: Error message for the rated response, or a falsy value.
        output: The assistant output. When ``status["is_image"]`` is true and
            this is a path to an existing file, that file is uploaded as
            ``data/<name>_plot.png``.
        last_prompt: The user prompt that produced ``output``.
        code: Generated code for the response, or a falsy value.
        status: Dict with an optional "is_image" flag; ``None`` is treated as {}.

    Returns:
        True if the Markdown report was uploaded. False if HF_TOKEN is missing
        or any exception occurred; errors are also shown in the UI. A failed
        image upload is only printed and does not change the result.
    """
    import tempfile

    status = status or {}
    markdown_local_path = None
    try:
        if not hf_token or hf_token.strip() == "":
            st.warning("⚠️ Cannot upload feedback - HF_TOKEN not available")
            return False

        # One timestamp for both the report body and the file name.
        now = datetime.now()
        feedback_data = {
            "timestamp": now.isoformat(),
            "session_id": st.session_state.session_id,
            "feedback_score": feedback.get("score", ""),
            "feedback_comment": feedback.get("text", ""),
            "user_prompt": last_prompt,
            "ai_output": str(output),
            "generated_code": code or "",
            "error_message": error or "",
            "is_image_output": status.get("is_image", False),
            "success": not bool(error),
        }

        # Timestamp plus a random suffix, so reports are unlikely to collide.
        timestamp_str = now.strftime("%Y%m%d_%H%M%S")
        random_id = str(uuid.uuid4())[:8]
        folder_name = f"feedback_{timestamp_str}_{random_id}"

        markdown_content = "\n".join([
            "# VayuChat Feedback Report",
            "## Session Information",
            f"- **Timestamp**: {feedback_data['timestamp']}",
            f"- **Session ID**: {feedback_data['session_id']}",
            "## User Interaction",
            f"**Prompt**: {feedback_data['user_prompt']}",
            "## AI Response",
            f"**Output**: {feedback_data['ai_output']}",
            "## Generated Code",
            "```python",
            f"{feedback_data['generated_code']}",
            "```",
            "## Technical Details",
            f"- **Error Message**: {feedback_data['error_message']}",
            f"- **Is Image Output**: {feedback_data['is_image_output']}",
            f"- **Success**: {feedback_data['success']}",
            "## User Feedback",
            f"- **Score**: {feedback_data['feedback_score']}",
            f"- **Comments**: {feedback_data['feedback_comment']}",
            "",  # keep the trailing newline
        ])

        # Stage the report in the platform temp dir (not a hard-coded /tmp).
        markdown_filename = f"{folder_name}.md"
        markdown_local_path = os.path.join(tempfile.gettempdir(), markdown_filename)
        with open(markdown_local_path, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        api = HfApi(token=hf_token)
        repo_id = "SustainabilityLabIITGN/VayuChat_Feedback"
        api.upload_file(
            path_or_fileobj=markdown_local_path,
            path_in_repo=f"data/{markdown_filename}",
            repo_id=repo_id,
            repo_type="dataset",
        )

        # Attach the plot when the rated answer was an image on disk.
        if status.get("is_image", False) and isinstance(output, str) and os.path.exists(output):
            try:
                image_filename = f"{folder_name}_plot.png"
                api.upload_file(
                    path_or_fileobj=output,
                    path_in_repo=f"data/{image_filename}",
                    repo_id=repo_id,
                    repo_type="dataset",
                )
            except Exception as img_error:
                # Best-effort: the text report was already uploaded.
                print(f"Error uploading image: {img_error}")

        st.success("πŸŽ‰ Feedback uploaded successfully!")
        return True
    except Exception as e:
        st.error(f"❌ Error uploading feedback: {e}")
        print(f"Feedback upload error: {e}")
        return False
    finally:
        # Remove the staged report on every path, including failed uploads,
        # which previously left it behind.
        if markdown_local_path and os.path.exists(markdown_local_path):
            try:
                os.remove(markdown_local_path)
            except OSError as cleanup_error:
                print(f"Error removing temporary feedback file: {cleanup_error}")
# --- MODERN HEADER & LAYOUT ---
# Second stylesheet. It is injected later in the page than the first one, so
# for equal selectors (.main-title, .subtitle, .instructions,
# .feedback-section, .main-container) its rules take precedence.
_HEADER_CSS = """
<style>
.main-title {text-align:center; color:#2d3748; font-size:2.8rem; font-weight:800; margin-bottom:0.2rem; letter-spacing:1px;}
.subtitle {text-align:center; color:#4b5563; font-size:1.2rem; margin-bottom:1.2rem;}
.instructions {background:#f1f5f9; border-left:4px solid #6366f1; padding:1rem; margin-bottom:1.2rem; border-radius:8px; color:#374151;}
.chat-bubble-user {background:linear-gradient(90deg,#6366f1 0%,#818cf8 100%);color:white;padding:14px 20px;border-radius:20px 20px 5px 20px;margin:10px 0 10px auto;max-width:80%;box-shadow:0 2px 10px rgba(0,0,0,0.08);}
.chat-bubble-ai {background:linear-gradient(90deg,#f472b6 0%,#fbbf24 100%);color:white;padding:14px 20px;border-radius:20px 20px 20px 5px;margin:10px auto 10px 0;max-width:80%;box-shadow:0 2px 10px rgba(0,0,0,0.08);}
.quick-prompt-btn {background:linear-gradient(90deg,#6366f1 0%,#818cf8 100%);color:white;border:none;padding:10px 20px;border-radius:22px;font-size:1rem;cursor:pointer;transition:all 0.2s;font-weight:500;margin:0 8px 8px 0;}
.quick-prompt-btn:hover {background:#6366f1;transform:scale(1.05);}
.chat-input-bar {background:#fff;border-radius:12px;box-shadow:0 2px 8px rgba(0,0,0,0.04);padding:1rem 1.5rem;margin-top:1.5rem;}
.sidebar-section {margin-bottom:1.5rem;}
.sidebar-title {font-weight:700;font-size:1.1rem;margin-bottom:0.5rem;}
.sidebar-data {font-size:0.95rem;color:#555;}
.feedback-section {background:#f3f4f6;border:1px solid #e5e7eb;padding:1rem;border-radius:10px;margin:1rem 0;}
.main-container {height:65vh;overflow-y:auto;padding-bottom:1rem;}
</style>
"""

# Page heading, tagline and usage hint, rendered in this order.
_HEADER_BLOCKS = (
    _HEADER_CSS,
    "<h1 class='main-title'>🌬️ VayuChat</h1>",
    "<div class='subtitle'><strong>AI-Powered Air Quality, Funding & Population Insights</strong><br>Simplifying analysis using conversational AI.</div>",
    "<div class='instructions'><strong>How to Use:</strong><br>Select a model from the sidebar and ask questions in the chat. Use quick prompts for common queries. Switch models and rerun easily!</div>",
)
for _header_html in _HEADER_BLOCKS:
    st.markdown(_header_html, unsafe_allow_html=True)
# PandasAI API key. setdefault keeps a value already supplied through the
# process environment or .env (loaded above with override=True). The previous
# unconditional assignment silently overwrote it.
# SECURITY NOTE(review): this literal is a credential committed to source
# control. It should be rotated and provided only via the environment.
os.environ.setdefault("PANDASAI_API_KEY", "$2a$10$gbmqKotzJOnqa7iYOun8eO50TxMD/6Zw1pLI2JEoqncwsNx4XeBS2")
# --- LOAD ALL DATAFRAMES ---
# Load every dataset up front. Any failure is reported in the UI and halts
# this script run, since the chat cannot work without the data.
# NOTE: pd.read_pickle can execute arbitrary code, so only load these
# bundled, trusted files with it.
try:
    df = preprocess_and_load_df(join(self_path, "Data.csv"))
    ncap_data = pd.read_pickle(join(self_path, "ncap_funding_data.pkl"))
    states_data = pd.read_pickle(join(self_path, "states_data.pkl"))
except Exception as e:
    st.error(f"❌ Error loading data: {e}")
    st.stop()
else:
    st.success("βœ… Data loaded successfully!")

inference_server = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
# Resolve the logo next to this script, like the data files above, so it
# still loads when Streamlit is started from a different working directory.
image_path = join(self_path, "IITGN_Logo.png")
# --- MODERN SIDEBAR ---
# NOTE(review): several emoji in these labels look mis-encoded (UTF-8 bytes
# shown through a legacy code page). Verify the file's encoding.
with st.sidebar:
    st.image(image_path, width=120)

    # Model selection: offer only models whose provider key is configured.
    st.markdown("<div class='sidebar-section sidebar-title'>πŸ€– Model Selection</div>", unsafe_allow_html=True)
    model_names = list(models.keys())
    groq_models = [m for m in model_names if "gemini" not in m]
    gemini_models = [m for m in model_names if "gemini" in m]
    groq_ready = bool(Groq_Token and Groq_Token.strip())
    gemini_ready = bool(gemini_token and gemini_token.strip())
    available_models = [
        *(groq_models if groq_ready else []),
        *(gemini_models if gemini_ready else []),
    ]
    if not available_models:
        st.error("❌ No API keys available! Please set up your API keys in the .env file")
        st.stop()

    model_name = st.selectbox("Choose your AI assistant:", available_models)

    # Short blurb per model; models without an entry simply show none.
    model_descriptions = {
        "llama3.1": "πŸ¦™ Fast and efficient for general queries",
        "llama3.3": "πŸ¦™ Most advanced LLaMA model for complex reasoning",
        "mistral": "⚑ Balanced performance and speed",
        "gemma": "πŸ’Ž Google's lightweight model",
        "gemini-pro": "🧠 Google's most powerful model",
        "gpt-oss-20b": "πŸ“˜ OpenAI's compact open-weight GPT for everyday tasks",
        "gpt-oss-120b": "πŸ“š OpenAI's massive open-weight GPT for nuanced responses",
        "deepseek-R1": "πŸ” DeepSeek's distilled LLaMA model for efficient reasoning",
        "llama4 maverik": "πŸš€ Meta's LLaMA 4 Maverick β€” high-performance instruction model",
        "llama4 scout": "πŸ›°οΈ Meta's LLaMA 4 Scout β€” optimized for adaptive reasoning",
    }
    selected_description = model_descriptions.get(model_name)
    if selected_description is not None:
        st.info(selected_description)

    # Static list of the datasets loaded at start-up.
    st.markdown("<div class='sidebar-section sidebar-title'>πŸ“Š Data Sources</div>", unsafe_allow_html=True)
    st.markdown("<div class='sidebar-data'>β€’ <b>Data.csv</b>: Air quality<br>β€’ <b>ncap_funding_data.pkl</b>: Funding<br>β€’ <b>states_data.pkl</b>: Population</div>", unsafe_allow_html=True)

    # Utilities: "Clear Chat" wipes history and starts a new session id.
    st.markdown("<div class='sidebar-section sidebar-title'>🧹 Utilities</div>", unsafe_allow_html=True)
    clear_clicked = st.button("Clear Chat")
    if clear_clicked:
        st.session_state.responses = []
        st.session_state.processing = False
        st.session_state.session_id = str(uuid.uuid4())
        st.rerun()

    short_session_id = st.session_state.session_id[:8]
    st.markdown(f"<div class='sidebar-section sidebar-title'>Session ID</div><div class='sidebar-data'>{short_session_id}...</div>", unsafe_allow_html=True)
# Load quick prompts: one prompt per non-blank line of questions.txt.
questions = []
questions_file = join(self_path, "questions.txt")
if os.path.exists(questions_file):
    try:
        with open(questions_file, 'r', encoding='utf-8') as f:
            raw_lines = f.read().split("\n")
        questions = [line.strip() for line in raw_lines if line.strip()]
        print(f"Loaded {len(questions)} quick prompts")  # Debug
    except Exception as e:
        st.error(f"Error loading questions: {e}")
        questions = []

# Built-in fallback when the file is missing, empty or unreadable.
if not questions:
    questions = [
        "What is the average PM2.5 level in the dataset?",
        "Show me the air quality trend over time",
        "Which pollutant has the highest concentration?",
        "Create a correlation plot between different pollutants",
        "What are the peak pollution hours?",
        "Compare weekday vs weekend pollution levels",
    ]

# --- MODERN QUICK PROMPTS ---
# A clicked button pre-fills the chat input further down via selected_prompt.
st.markdown("<div style='margin-bottom:0.5rem;'><b>πŸ’­ Quick Prompts</b></div>", unsafe_allow_html=True)
selected_prompt = None
prompt_buttons = []
for button_index, quick_question in enumerate(questions):
    clicked = st.button(quick_question, key=f"quick_{button_index}", help=quick_question, use_container_width=True)
    if clicked:
        selected_prompt = quick_question
# Chat history and the "agent is busy" flag persist across reruns in
# session_state; seed them on the first run only.
for _state_key, _state_default in (("responses", []), ("processing", False)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def show_custom_response(response):
    """Render one chat message with the .user-message / .assistant-message styles.

    Args:
        response: Message dict with "role" (defaults to "assistant"),
            "content" (defaults to "") and, optionally, "gen_code". For
            assistant messages, "gen_code" is shown in an expander.

    Returns:
        {"is_image": True} if ``content`` is a path to an existing .png/.jpg
        file that was displayed with st.image, otherwise {"is_image": False}.
    """
    role = response.get("role", "assistant")
    content = response.get("content", "")

    if role == "user":
        st.markdown(f"""
<div class='user-message'>
<div class='user-info'>You</div>
{content}
</div>
""", unsafe_allow_html=True)
    elif role == "assistant":
        st.markdown(f"""
<div class='assistant-message'>
<div class='assistant-info'>πŸ€– VayuChat</div>
{content if isinstance(content, str) else str(content)}
</div>
""", unsafe_allow_html=True)
        # Show generated code if available
        if response.get("gen_code"):
            with st.expander("πŸ“‹ View Generated Code"):
                st.code(response["gen_code"], language="python")

    # Try to display image if content is a file path
    try:
        if isinstance(content, str) and content.endswith((".png", ".jpg")) and os.path.exists(content):
            st.image(content)
            return {"is_image": True}
    except Exception:
        # Best-effort: an unreadable image falls back to the text already
        # shown. Catching Exception instead of using a bare except lets
        # KeyboardInterrupt and other BaseException signals propagate.
        pass
    return {"is_image": False}
def show_processing_indicator(model_name, question):
    """Render the pulsing "processing" bubble.

    The bubble names the model and echoes the question being answered.
    """
    indicator_html = f"""
<div class='processing-indicator'>
<div class='assistant-info'>πŸ€– VayuChat β€’ Processing with {model_name}</div>
<strong>Question:</strong> {question}<br>
<em>πŸ”„ Generating response...</em>
</div>
"""
    st.markdown(indicator_html, unsafe_allow_html=True)
# --- MODERN MAIN CHAT AREA ---
def _is_image_file(content):
    """Return True when *content* is a path to an existing .png/.jpg file."""
    return (
        isinstance(content, str)
        and content.endswith((".png", ".jpg"))
        and os.path.exists(content)
    )


def _render_feedback(idx, response, status):
    """Show the stored rating for an assistant message, or the widgets to give one.

    Args:
        idx: Position of *response* in st.session_state.responses. Combined
            with the session id, it builds widget keys that are stable across
            reruns. "Clear Chat" issues a new session id, which resets them.
        response: The assistant message dict. It gains a "feedback" key once a
            rating has been uploaded successfully.
        status: Forwarded to upload_feedback (e.g. {"is_image": True}).
    """
    if "feedback" in response:
        feedback_data = response["feedback"]
        feedback_text = feedback_data.get('text', '')
        feedback_score = feedback_data.get('score', '')
        feedback_str = f"- {feedback_text}" if feedback_text else ''
        st.markdown(f"<div class='feedback-section'><strong>πŸ“ Your Feedback:</strong> {feedback_score} {feedback_str}</div>", unsafe_allow_html=True)
        return

    key_suffix = f"{st.session_state.session_id}_{idx}"
    choice_key = f"fb_choice_{key_suffix}"
    st.markdown("<div class='feedback-section'><b>How was this response?</b>", unsafe_allow_html=True)
    col1, col2 = st.columns(2)
    with col1:
        thumbs_up = st.button("πŸ‘ Helpful", key=f"fb_up_{key_suffix}")
    with col2:
        thumbs_down = st.button("πŸ‘Ž Not Helpful", key=f"fb_down_{key_suffix}")

    # st.button is True only on the rerun right after its click. Without
    # remembering the rating, clicking "Submit" (another rerun) would hide the
    # form before the submit branch ran, so feedback could never be sent.
    if thumbs_up or thumbs_down:
        st.session_state[choice_key] = "πŸ‘ Helpful" if thumbs_up else "πŸ‘Ž Not Helpful"
    thumbs = st.session_state.get(choice_key)

    if thumbs:
        comments = st.text_area("πŸ’¬ Tell us more (optional):", key=f"fb_comments_{key_suffix}", placeholder="What could be improved? Any suggestions?", max_chars=500)
        if st.button("πŸš€ Submit Feedback", key=f"fb_submit_{key_suffix}"):
            feedback = {"score": thumbs, "text": comments}
            if upload_feedback(feedback, response.get("error", ""), response.get("content", ""), response.get("last_prompt", ""), response.get("gen_code", ""), status):
                response["feedback"] = feedback
                time.sleep(1)  # let the success message show before rerunning
                st.rerun()
    st.markdown("</div>", unsafe_allow_html=True)


# NOTE(review): each st.markdown call is its own element, so these open/close
# <div> calls likely do not wrap the widgets rendered between them.
st.markdown("<div class='main-container'>", unsafe_allow_html=True)
for idx, response in enumerate(st.session_state.responses):
    if response["role"] == "user":
        st.markdown(f"<div class='chat-bubble-user'>{response['content']}</div>", unsafe_allow_html=True)
        continue

    st.markdown(f"<div class='chat-bubble-ai'>{response['content']}</div>", unsafe_allow_html=True)
    # Assistant content may be a path to a generated plot; show the plot itself.
    is_image = _is_image_file(response['content'])
    if is_image:
        st.image(response['content'])
    # Show code details if available
    if response.get("gen_code"):
        with st.expander("πŸ“‹ View Generated Code"):
            st.code(response["gen_code"], language="python")
    # Passing the real is_image flag lets upload_feedback attach the plot.
    _render_feedback(idx, response, {"is_image": is_image})
st.markdown("</div>", unsafe_allow_html=True)
# --- MODERN CHAT INPUT & RERUN ---
st.markdown("<div class='chat-input-bar'>", unsafe_allow_html=True)
rerun_col, input_col = st.columns([1, 8])

with rerun_col:
    rerun_clicked = st.button("πŸ”„ Rerun Last", help="Rerun the last question with the selected model")
    previous_prompt = st.session_state.get("last_prompt")
    # Re-queue the previous question for the currently selected model.
    if rerun_clicked and previous_prompt:
        st.session_state.processing = True
        st.session_state.current_model = model_name
        st.session_state.current_question = previous_prompt
        st.rerun()

with input_col:
    # A clicked quick prompt (selected_prompt) pre-fills the text box.
    prompt = st.text_input("Ask about air quality, funding, or population...", value=selected_prompt or "", key="main_chat")

st.markdown("</div>", unsafe_allow_html=True)
# Handle new queries. The text box keeps its value across reruns, so skip the
# exact prompt/model pair that was answered last; otherwise every rerun would
# ask it again.
if prompt and not st.session_state.get("processing"):
    already_answered = (
        "last_prompt" in st.session_state
        and prompt == st.session_state["last_prompt"]
        and model_name == st.session_state.get("last_model_name", "")
    )
    if already_answered:
        prompt = None
    else:
        # Record the user's message, then hand off to the processing branch.
        st.session_state.responses.append(get_from_user(prompt))
        st.session_state.processing = True
        st.session_state.current_model = model_name
        st.session_state.current_question = prompt
        st.rerun()
# Process the question if we're in processing state.
# The question and model were queued in session_state by the new-query
# handler or the "Rerun Last" button.
if st.session_state.get("processing"):
    prompt = st.session_state.get("current_question")
    model_name = st.session_state.get("current_model")

    # Visible feedback while the potentially slow agent call runs. The rerun
    # at the end of this branch replaces it with the answer.
    show_processing_indicator(model_name, prompt)

    try:
        from src import SYSTEM_PROMPT
        agent = load_agent(df, SYSTEM_PROMPT, inference_server, name=model_name)
        response = ask_agent(agent, prompt)
        if not isinstance(response, dict):
            response = {
                "role": "assistant",
                "content": "❌ Error: Invalid response format",
                "gen_code": "",
                "ex_code": "",
                "last_prompt": prompt,
                "error": "Invalid response format",
            }
        # Normalise the message so the chat renderer can rely on these keys.
        response.setdefault("role", "assistant")
        response.setdefault("content", "No content generated")
        response.setdefault("gen_code", "")
        response.setdefault("ex_code", "")
        response.setdefault("last_prompt", prompt)
        response.setdefault("error", None)
    except Exception as e:
        response = {
            "role": "assistant",
            "content": f"Sorry, I encountered an error: {str(e)}",
            "gen_code": "",
            "ex_code": "",
            "last_prompt": prompt,
            "error": str(e),
        }

    st.session_state.responses.append(response)
    # Remember what was answered so the input handler does not re-ask it.
    st.session_state["last_prompt"] = prompt
    st.session_state["last_model_name"] = model_name
    st.session_state.processing = False
    # Clear processing state
    if "current_model" in st.session_state:
        del st.session_state.current_model
    if "current_question" in st.session_state:
        del st.session_state.current_question
    st.rerun()
# Auto-scroll to bottom: once any message exists, emit a call to the
# scrollToBottom() helper injected near the top of the script.
# NOTE(review): st.markdown is not expected to execute <script> tags, so this
# call is likely inert in the browser. Verify before relying on it.
has_history = bool(st.session_state.responses)
if has_history:
    st.markdown("<script>scrollToBottom();</script>", unsafe_allow_html=True)
# Sidebar footer linking the VayuChat paper (currently disabled):
# with st.sidebar:
# st.markdown("---")
# st.markdown("""
# <div class='contact-section'>
# <h4>πŸ“„ Paper on VayuChat</h4>
# <p>Learn more about VayuChat in our <a href='https://arxiv.org/abs/2411.12760' target='_blank'>Research Paper</a>.</p>
# </div>
# """, unsafe_allow_html=True)

# Session statistics, shown only when feedback logging (HF_TOKEN) is set.
if hf_token and hf_token.strip():
    st.markdown("### πŸ“ˆ Session Stats")
    assistant_turns = [
        r for r in st.session_state.get("responses", [])
        if r.get("role") == "assistant"
    ]
    st.metric("Interactions", len(assistant_turns))
    st.metric("Feedbacks Given", sum(1 for r in assistant_turns if "feedback" in r))

# Page footer.
_FOOTER_HTML = """
<div style='text-align: center; margin-top: 3rem; padding: 2rem; background: rgba(255,255,255,0.1); border-radius: 15px;'>
<h3>🌍 Together for Cleaner Air</h3>
<p>VayuChat - Empowering environmental awareness through AI</p>
<small>Β© 2025 IIT Gandhinagar Sustainability Lab</small>
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)