# aoe-demo / src/streamlit_app.py
# Last commit by ItCodinTime: "Replace spiral demo with LLM comparison interface" (7cb9cc4, verified)
import streamlit as st
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import traceback
from typing import Optional
# Configure the Streamlit page (browser-tab title, icon and wide layout) once,
# at import time, before any other content is rendered.
# NOTE(review): the emoji literals in this file look mojibake-encoded (UTF-8
# bytes decoded with a legacy code page); confirm the file's encoding.
_PAGE_SETTINGS = {
    "page_title": "LLM Comparison: GPT-4 vs Gemini vs AOE",
    "page_icon": "βš”οΈ",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
def load_aoe_model():
"""Load the AoE model and tokenizer from outputs/student/ directory"""
model_path = "outputs/student/"
try:
if not os.path.exists(model_path):
st.error(f"Model directory '{model_path}' not found. Please ensure the model files are present.")
return None, None
# Check if required files exist
required_files = ["config.json", "pytorch_model.bin", "tokenizer.json"]
missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
if missing_files:
st.warning(f"Some model files may be missing: {missing_files}. Attempting to load anyway...")
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
return model, tokenizer
except Exception as e:
st.error(f"Error loading AoE model: {str(e)}")
st.text(f"Traceback: {traceback.format_exc()}")
return None, None
def generate_aoe_response(model, tokenizer, prompt, max_length=512, temperature=0.7):
    """Generate a sampled completion for *prompt* with the local AoE model.

    Args:
        model: Causal LM exposing ``generate`` (as returned by ``load_aoe_model``).
        tokenizer: The tokenizer matching *model*.
        prompt: User prompt text.
        max_length: Maximum number of *new* tokens to generate; the prompt's
            own tokens are not counted.
        temperature: Sampling temperature passed to ``generate``.

    Returns:
        The generated continuation without the prompt. If anything fails, the
        string ``"Error generating AoE response: <error>"`` is returned instead.
    """
    try:
        encoded = tokenizer(prompt, return_tensors="pt")

        # Put the inputs on the same device as the model's weights. This works
        # for CPU, a single GPU, or device_map placement, unlike a bare .cuda().
        device = getattr(model, "device", None)
        if device is None:
            device = next(model.parameters()).device
        input_ids = encoded["input_ids"].to(device)
        # Pass the mask explicitly. With pad_token_id == eos_token_id, generate
        # cannot infer it reliably.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # A causal LM returns prompt + continuation, so slice off the prompt
        # tokens before decoding. Stripping the prompt as a string prefix fails
        # whenever detokenization does not reproduce the prompt exactly.
        prompt_len = input_ids.shape[-1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    except Exception as e:
        return f"Error generating AoE response: {str(e)}"
def query_gpt4_api(prompt: str, api_key: Optional[str] = None) -> str:
    """Return a GPT-4 answer for *prompt* -- currently a stub.

    No OpenAI request is made yet and *prompt* is unused. A missing or empty
    *api_key* yields a "not configured" notice; any other key yields a
    placeholder message.
    """
    # TODO: replace the placeholder with a real OpenAI API call, reporting
    # failures as "Error querying GPT-4: <error>".
    if api_key:
        return "πŸ€– GPT-4 response would appear here with proper API configuration."
    return "❌ GPT-4 API key not configured. Please add your OpenAI API key to use GPT-4."
def query_gemini_api(prompt: str, api_key: Optional[str] = None) -> str:
    """Return a Gemini answer for *prompt* -- currently a stub.

    Nothing is sent to Google yet and *prompt* is ignored. Without a
    (non-empty) *api_key* a "not configured" notice is returned; otherwise
    a placeholder message.
    """
    # TODO: call the Gemini API here and report failures as
    # "Error querying Gemini: <error>".
    not_configured = "❌ Gemini API key not configured. Please add your Google API key to use Gemini."
    placeholder = "πŸ€– Gemini response would appear here with proper API configuration."
    return placeholder if api_key else not_configured
def _init_aoe_state():
    """Create the session-state keys that cache the AOE model per browser session."""
    defaults = {
        "aoe_model": None,
        "aoe_tokenizer": None,
        "aoe_loaded": False,
        # Set once a load has been tried. Every widget interaction reruns the
        # script, and without this flag a failing load would be retried each time.
        "aoe_load_attempted": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _ensure_aoe_model_loaded():
    """Load the AOE model once per session (or after a reload request)."""
    if st.session_state.aoe_loaded or st.session_state.aoe_load_attempted:
        return
    st.session_state.aoe_load_attempted = True
    with st.spinner("Loading AOE model from outputs/student/..."):
        model, tokenizer = load_aoe_model()
    if model is not None and tokenizer is not None:
        st.session_state.aoe_model = model
        st.session_state.aoe_tokenizer = tokenizer
        st.session_state.aoe_loaded = True
        st.success("βœ… AOE model loaded successfully!")
    else:
        st.error("❌ Failed to load AOE model. Check error messages above.")


def _render_configuration():
    """Render the API-key inputs and length slider; return (openai_key, google_key, max_length)."""
    st.markdown("---")
    st.subheader("πŸ”§ Configuration")
    col1, col2, col3 = st.columns(3)
    with col1:
        openai_api_key = st.text_input(
            "OpenAI API Key (for GPT-4)",
            type="password",
            help="Enter your OpenAI API key to enable GPT-4 responses",
        )
    with col2:
        google_api_key = st.text_input(
            "Google API Key (for Gemini)",
            type="password",
            help="Enter your Google API key to enable Gemini responses",
        )
    with col3:
        max_length = st.slider(
            "Max Response Length",
            min_value=100,
            max_value=1000,
            value=512,
            step=50,
            help="Maximum length for generated responses",
        )
    return openai_api_key, google_api_key, max_length


def _render_comparison(openai_api_key, google_api_key, max_length):
    """Render the prompt box and, when the button is clicked, the three responses side by side."""
    st.markdown("---")
    st.subheader("πŸ’¬ Compare LLM Responses")
    user_prompt = st.text_area(
        "Enter your prompt:",
        placeholder="Type your prompt here to compare responses from all three models...",
        height=120,
        help="Enter a prompt to see how different LLMs respond",
    )
    if not st.button("πŸš€ Generate All Responses", type="primary"):
        return
    if not user_prompt.strip():
        st.warning("Please enter a prompt first.")
        return

    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown("### πŸ€– GPT-4")
        with st.spinner("Generating GPT-4 response..."):
            gpt4_response = query_gpt4_api(user_prompt, openai_api_key)
        st.markdown("**Response:**")
        st.write(gpt4_response)
    with col2:
        st.markdown("### 🌟 Gemini")
        with st.spinner("Generating Gemini response..."):
            gemini_response = query_gemini_api(user_prompt, google_api_key)
        st.markdown("**Response:**")
        st.write(gemini_response)
    with col3:
        st.markdown("### 🏰 AOE (Local)")
        if st.session_state.aoe_loaded:
            with st.spinner("Generating AOE response..."):
                aoe_response = generate_aoe_response(
                    st.session_state.aoe_model,
                    st.session_state.aoe_tokenizer,
                    user_prompt,
                    max_length,
                )
            st.markdown("**Response:**")
            st.write(aoe_response)
        else:
            st.error("AOE model not loaded. Please reload the page.")


def _render_sidebar(openai_api_key, google_api_key):
    """Render model status, the reload button and usage notes in the sidebar."""
    with st.sidebar:
        st.header("ℹ️ Model Information")
        st.markdown("**πŸ€– GPT-4**")
        st.write(f"Status: {'βœ… Configured' if openai_api_key else '❌ API key needed'}")
        st.write("Provider: OpenAI")
        st.markdown("**🌟 Gemini**")
        st.write(f"Status: {'βœ… Configured' if google_api_key else '❌ API key needed'}")
        st.write("Provider: Google")
        st.markdown("**🏰 AOE (Local)**")
        st.write(f"Status: {'βœ… Loaded' if st.session_state.aoe_loaded else '❌ Not loaded'}")
        st.write("Path: outputs/student/")
        if st.session_state.aoe_loaded:
            try:
                device_info = f"Device: {next(st.session_state.aoe_model.parameters()).device}"
                st.write(device_info)
            except (AttributeError, StopIteration):
                # Best effort: omit the device line if it cannot be determined.
                pass
        if st.button("πŸ”„ Reload AOE Model"):
            st.session_state.aoe_loaded = False
            st.session_state.aoe_load_attempted = False
            # Drop references to the old model so its memory can be reclaimed
            # before the new copy is loaded.
            st.session_state.aoe_model = None
            st.session_state.aoe_tokenizer = None
            # st.experimental_rerun was removed in newer Streamlit releases;
            # use st.rerun where it exists.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()
        st.markdown("---")
        st.markdown("**πŸ“‹ Instructions:**")
        st.markdown("1. Configure API keys for GPT-4 and Gemini")
        st.markdown("2. Enter your prompt in the text area")
        st.markdown("3. Click 'Generate All Responses'")
        st.markdown("4. Compare responses side by side")
        st.markdown("---")
        st.markdown("**⚠️ Notes:**")
        st.markdown("- GPT-4 and Gemini require valid API keys")
        st.markdown("- AOE model runs locally from outputs/student/")
        st.markdown("- Responses are generated independently")


def main():
    """Render the LLM comparison page; Streamlit re-executes this on every interaction."""
    st.title("βš”οΈ LLM Comparison: GPT-4 vs Gemini vs AOE")
    st.markdown("Compare responses from three different language models side by side.")
    _init_aoe_state()
    _ensure_aoe_model_loaded()
    openai_api_key, google_api_key, max_length = _render_configuration()
    _render_comparison(openai_api_key, google_api_key, max_length)
    _render_sidebar(openai_api_key, google_api_key)


if __name__ == "__main__":
    main()