import streamlit as st
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import traceback
from typing import Optional

# Configure the page
st.set_page_config(
    page_title="LLM Comparison: GPT-4 vs Gemini vs AOE",
    page_icon="⚖️",
    layout="wide"
)
def load_aoe_model():
    """Load the AoE model and tokenizer from the outputs/student/ directory."""
    model_path = "outputs/student/"
    try:
        if not os.path.exists(model_path):
            st.error(f"Model directory '{model_path}' not found. Please ensure the model files are present.")
            return None, None

        # Check if required files exist (newer checkpoints may ship
        # model.safetensors instead of pytorch_model.bin, hence only a warning)
        required_files = ["config.json", "pytorch_model.bin", "tokenizer.json"]
        missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
        if missing_files:
            st.warning(f"Some model files may be missing: {missing_files}. Attempting to load anyway...")

        # Load tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading AoE model: {str(e)}")
        st.text(f"Traceback: {traceback.format_exc()}")
        return None, None
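
# NOTE: a minimal alternative sketch — Streamlit's st.cache_resource decorator
# could replace the manual session_state caching done in main() below, so the
# model is loaded once per server process. This wrapper is an assumption about
# preferred style, not part of the original app:
#
# @st.cache_resource
# def load_aoe_model_cached():
#     """Cached wrapper around load_aoe_model()."""
#     return load_aoe_model()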

def generate_aoe_response(model, tokenizer, prompt, max_length=512):
    """Generate a response from the AoE model."""
    try:
        # Tokenize input (the attention_mask avoids a generate() warning)
        inputs = tokenizer(prompt, return_tensors="pt")

        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode the output and strip the echoed input prompt
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
        return response
    except Exception as e:
        return f"Error generating AoE response: {str(e)}"

def query_gpt4_api(prompt: str, api_key: Optional[str] = None) -> str:
    """Query the GPT-4 API (placeholder - requires an API key)."""
    if not api_key:
        return "❌ GPT-4 API key not configured. Please add your OpenAI API key to use GPT-4."
    try:
        # Placeholder implementation - would need actual OpenAI API integration
        return "🤖 GPT-4 response would appear here with proper API configuration."
    except Exception as e:
        return f"Error querying GPT-4: {str(e)}"

def query_gemini_api(prompt: str, api_key: Optional[str] = None) -> str:
    """Query the Gemini API (placeholder - requires an API key)."""
    if not api_key:
        return "❌ Gemini API key not configured. Please add your Google API key to use Gemini."
    try:
        # Placeholder implementation - would need actual Google Gemini API integration
        return "🤖 Gemini response would appear here with proper API configuration."
    except Exception as e:
        return f"Error querying Gemini: {str(e)}"

def main():
    st.title("⚖️ LLM Comparison: GPT-4 vs Gemini vs AOE")
    st.markdown("Compare responses from three different language models side by side.")

    # Initialize session state for model caching
    if 'aoe_model' not in st.session_state:
        st.session_state.aoe_model = None
        st.session_state.aoe_tokenizer = None
        st.session_state.aoe_loaded = False

    # Load the AOE model on first run
    if not st.session_state.aoe_loaded:
        with st.spinner("Loading AOE model from outputs/student/..."):
            model, tokenizer = load_aoe_model()
            if model is not None and tokenizer is not None:
                st.session_state.aoe_model = model
                st.session_state.aoe_tokenizer = tokenizer
                st.session_state.aoe_loaded = True
                st.success("✅ AOE model loaded successfully!")
            else:
                st.error("❌ Failed to load AOE model. Check error messages above.")

    # Configuration section
    st.markdown("---")
    st.subheader("🔧 Configuration")
    col1, col2, col3 = st.columns(3)

    with col1:
        openai_api_key = st.text_input(
            "OpenAI API Key (for GPT-4)",
            type="password",
            help="Enter your OpenAI API key to enable GPT-4 responses"
        )
    with col2:
        google_api_key = st.text_input(
            "Google API Key (for Gemini)",
            type="password",
            help="Enter your Google API key to enable Gemini responses"
        )
    with col3:
        max_length = st.slider(
            "Max Response Length",
            min_value=100,
            max_value=1000,
            value=512,
            step=50,
            help="Maximum length for generated responses"
        )
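
    # NOTE: a sketch — the keys could instead be read from Streamlit secrets so
    # they are not retyped each session (the secrets file and key names are
    # assumptions, not part of the original app):
    #
    #     openai_api_key = openai_api_key or st.secrets.get("OPENAI_API_KEY", "")
    #     google_api_key = google_api_key or st.secrets.get("GOOGLE_API_KEY", "")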

    # Main comparison interface
    st.markdown("---")
    st.subheader("💬 Compare LLM Responses")

    # User input
    user_prompt = st.text_area(
        "Enter your prompt:",
        placeholder="Type your prompt here to compare responses from all three models...",
        height=120,
        help="Enter a prompt to see how different LLMs respond"
    )

    # Generate responses button
    if st.button("🚀 Generate All Responses", type="primary"):
        if not user_prompt.strip():
            st.warning("Please enter a prompt first.")
        else:
            # Create three columns for side-by-side comparison
            col1, col2, col3 = st.columns(3)

            with col1:
                st.markdown("### 🤖 GPT-4")
                with st.spinner("Generating GPT-4 response..."):
                    gpt4_response = query_gpt4_api(user_prompt, openai_api_key)
                st.markdown("**Response:**")
                st.write(gpt4_response)

            with col2:
                st.markdown("### 🌟 Gemini")
                with st.spinner("Generating Gemini response..."):
                    gemini_response = query_gemini_api(user_prompt, google_api_key)
                st.markdown("**Response:**")
                st.write(gemini_response)

            with col3:
                st.markdown("### 🏰 AOE (Local)")
                if st.session_state.aoe_loaded:
                    with st.spinner("Generating AOE response..."):
                        aoe_response = generate_aoe_response(
                            st.session_state.aoe_model,
                            st.session_state.aoe_tokenizer,
                            user_prompt,
                            max_length
                        )
                    st.markdown("**Response:**")
                    st.write(aoe_response)
                else:
                    st.error("AOE model not loaded. Please reload the page.")

    # Model information sidebar
    with st.sidebar:
        st.header("ℹ️ Model Information")

        st.markdown("**🤖 GPT-4**")
        st.write(f"Status: {'✅ Configured' if openai_api_key else '❌ API key needed'}")
        st.write("Provider: OpenAI")

        st.markdown("**🌟 Gemini**")
        st.write(f"Status: {'✅ Configured' if google_api_key else '❌ API key needed'}")
        st.write("Provider: Google")

        st.markdown("**🏰 AOE (Local)**")
        st.write(f"Status: {'✅ Loaded' if st.session_state.aoe_loaded else '❌ Not loaded'}")
        st.write("Path: outputs/student/")
        if st.session_state.aoe_loaded:
            try:
                device_info = f"Device: {next(st.session_state.aoe_model.parameters()).device}"
                st.write(device_info)
            except Exception:
                pass

        if st.button("🔄 Reload AOE Model"):
            st.session_state.aoe_loaded = False
            st.rerun()  # st.experimental_rerun() is deprecated in recent Streamlit
| st.markdown("---") | |
| st.markdown("**π Instructions:**") | |
| st.markdown("1. Configure API keys for GPT-4 and Gemini") | |
| st.markdown("2. Enter your prompt in the text area") | |
| st.markdown("3. Click 'Generate All Responses'") | |
| st.markdown("4. Compare responses side by side") | |
| st.markdown("---") | |
| st.markdown("**β οΈ Notes:**") | |
| st.markdown("- GPT-4 and Gemini require valid API keys") | |
| st.markdown("- AOE model runs locally from outputs/student/") | |
| st.markdown("- Responses are generated independently") | |

if __name__ == "__main__":
    main()