ItCodinTime committed on
Commit
7cb9cc4
Β·
verified Β·
1 Parent(s): d649435

Replace spiral demo with LLM comparison interface

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +235 -32
src/streamlit_app.py CHANGED
@@ -1,40 +1,243 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
 
 
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
 
 
 
 
 
 
 
 
18
 
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
 
 
 
 
 
 
 
22
 
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import traceback
6
+ from typing import Optional
7
 
8
# Configure the page: must be the first Streamlit call in the script.
# Wide layout gives room for the three side-by-side response columns.
st.set_page_config(
    page_title="LLM Comparison: GPT-4 vs Gemini vs AOE",
    page_icon="βš”οΈ",
    layout="wide"
)
14
 
15
def load_aoe_model():
    """Load the AoE student model and tokenizer from outputs/student/.

    Returns:
        (model, tokenizer) on success, or (None, None) when the model
        directory is missing or loading fails. Errors are surfaced through
        Streamlit widgets rather than raised.
    """
    model_path = "outputs/student/"

    try:
        if not os.path.exists(model_path):
            st.error(f"Model directory '{model_path}' not found. Please ensure the model files are present.")
            return None, None

        # Weights may be saved as legacy pytorch_model.bin OR modern
        # model.safetensors — accept either instead of warning on the
        # safetensors layout.
        weight_candidates = ("pytorch_model.bin", "model.safetensors")
        has_weights = any(
            os.path.exists(os.path.join(model_path, name)) for name in weight_candidates
        )

        # Non-weight files we expect alongside the checkpoint.
        required_files = ["config.json", "tokenizer.json"]
        missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
        if not has_weights:
            missing_files.append("pytorch_model.bin / model.safetensors")

        # Only warn (don't fail): sharded or alternate layouts can still load.
        if missing_files:
            st.warning(f"Some model files may be missing: {missing_files}. Attempting to load anyway...")

        # Load tokenizer and model; use fp16 + automatic device placement
        # only when a GPU is available.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )

        return model, tokenizer

    except Exception as e:
        st.error(f"Error loading AoE model: {str(e)}")
        st.text(f"Traceback: {traceback.format_exc()}")
        return None, None
46
 
47
def generate_aoe_response(model, tokenizer, prompt, max_length=512):
    """Generate a completion for *prompt* from the local AoE model.

    Args:
        model: causal LM loaded via AutoModelForCausalLM.
        tokenizer: the matching tokenizer.
        prompt: user prompt string.
        max_length: maximum number of NEW tokens to generate (prompt length
            is excluded, matching the original len(prompt)+max_length bound).

    Returns:
        The generated continuation as a string, or an error-message string
        if generation fails.
    """
    try:
        # Tokenize with an attention mask so generate() behaves correctly
        # when pad_token_id == eos_token_id (the common GPT-style setup).
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")

        # Move inputs to the model's device (works for both CPU and CUDA,
        # instead of only handling the CUDA case).
        device = next(model.parameters()).device
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # max_new_tokens bounds only the continuation — equivalent to the
        # previous max_length=len(inputs[0]) + max_length, but explicit.
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode and strip the echoed prompt, keeping only the continuation.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if response.startswith(prompt):
            response = response[len(prompt):].strip()

        return response

    except Exception as e:
        return f"Error generating AoE response: {str(e)}"
79
 
80
def query_gpt4_api(prompt: str, api_key: Optional[str] = None) -> str:
    """Placeholder GPT-4 query.

    Returns a "key not configured" notice when *api_key* is falsy; otherwise
    returns a canned placeholder (no real OpenAI call is made yet).
    """
    if api_key:
        try:
            # Placeholder — a real implementation would call the OpenAI API here.
            return "πŸ€– GPT-4 response would appear here with proper API configuration."
        except Exception as e:
            return f"Error querying GPT-4: {str(e)}"
    return "❌ GPT-4 API key not configured. Please add your OpenAI API key to use GPT-4."
90
 
91
def query_gemini_api(prompt: str, api_key: Optional[str] = None) -> str:
    """Placeholder Gemini query.

    Returns a "key not configured" notice when *api_key* is falsy; otherwise
    returns a canned placeholder (no real Google API call is made yet).
    """
    if api_key:
        try:
            # Placeholder — a real implementation would call the Gemini API here.
            return "πŸ€– Gemini response would appear here with proper API configuration."
        except Exception as e:
            return f"Error querying Gemini: {str(e)}"
    return "❌ Gemini API key not configured. Please add your Google API key to use Gemini."
101
 
102
def main():
    """Render the three-way LLM comparison UI (GPT-4 / Gemini / local AoE).

    Loads the local AoE model once into st.session_state, collects API keys
    and a prompt from the user, and shows the three responses side by side
    with a status sidebar.
    """
    st.title("βš”οΈ LLM Comparison: GPT-4 vs Gemini vs AOE")
    st.markdown("Compare responses from three different language models side by side.")

    # Initialize session state for model caching so reruns don't reload.
    if 'aoe_model' not in st.session_state:
        st.session_state.aoe_model = None
        st.session_state.aoe_tokenizer = None
        st.session_state.aoe_loaded = False

    # Load AOE model on first run (or after a requested reload).
    if not st.session_state.aoe_loaded:
        with st.spinner("Loading AOE model from outputs/student/..."):
            model, tokenizer = load_aoe_model()
            if model is not None and tokenizer is not None:
                st.session_state.aoe_model = model
                st.session_state.aoe_tokenizer = tokenizer
                st.session_state.aoe_loaded = True
                st.success("βœ… AOE model loaded successfully!")
            else:
                st.error("❌ Failed to load AOE model. Check error messages above.")

    # Configuration section: API keys and generation length.
    st.markdown("---")
    st.subheader("πŸ”§ Configuration")

    col1, col2, col3 = st.columns(3)

    with col1:
        openai_api_key = st.text_input(
            "OpenAI API Key (for GPT-4)",
            type="password",
            help="Enter your OpenAI API key to enable GPT-4 responses"
        )

    with col2:
        google_api_key = st.text_input(
            "Google API Key (for Gemini)",
            type="password",
            help="Enter your Google API key to enable Gemini responses"
        )

    with col3:
        max_length = st.slider(
            "Max Response Length",
            min_value=100,
            max_value=1000,
            value=512,
            step=50,
            help="Maximum length for generated responses"
        )

    # Main comparison interface.
    st.markdown("---")
    st.subheader("πŸ’¬ Compare LLM Responses")

    # User input
    user_prompt = st.text_area(
        "Enter your prompt:",
        placeholder="Type your prompt here to compare responses from all three models...",
        height=120,
        help="Enter a prompt to see how different LLMs respond"
    )

    # Generate responses button
    if st.button("πŸš€ Generate All Responses", type="primary"):
        if not user_prompt.strip():
            st.warning("Please enter a prompt first.")
        else:
            # Three columns for side-by-side comparison.
            col1, col2, col3 = st.columns(3)

            with col1:
                st.markdown("### πŸ€– GPT-4")
                with st.spinner("Generating GPT-4 response..."):
                    gpt4_response = query_gpt4_api(user_prompt, openai_api_key)
                st.markdown("**Response:**")
                st.write(gpt4_response)

            with col2:
                st.markdown("### 🌟 Gemini")
                with st.spinner("Generating Gemini response..."):
                    gemini_response = query_gemini_api(user_prompt, google_api_key)
                st.markdown("**Response:**")
                st.write(gemini_response)

            with col3:
                st.markdown("### 🏰 AOE (Local)")
                if st.session_state.aoe_loaded:
                    with st.spinner("Generating AOE response..."):
                        aoe_response = generate_aoe_response(
                            st.session_state.aoe_model,
                            st.session_state.aoe_tokenizer,
                            user_prompt,
                            max_length
                        )
                    st.markdown("**Response:**")
                    st.write(aoe_response)
                else:
                    st.error("AOE model not loaded. Please reload the page.")

    # Model information sidebar.
    with st.sidebar:
        st.header("ℹ️ Model Information")

        st.markdown("**πŸ€– GPT-4**")
        st.write(f"Status: {'βœ… Configured' if openai_api_key else '❌ API key needed'}")
        st.write("Provider: OpenAI")

        st.markdown("**🌟 Gemini**")
        st.write(f"Status: {'βœ… Configured' if google_api_key else '❌ API key needed'}")
        st.write("Provider: Google")

        st.markdown("**🏰 AOE (Local)**")
        st.write(f"Status: {'βœ… Loaded' if st.session_state.aoe_loaded else '❌ Not loaded'}")
        st.write("Path: outputs/student/")
        if st.session_state.aoe_loaded:
            try:
                device_info = f"Device: {next(st.session_state.aoe_model.parameters()).device}"
                st.write(device_info)
            except Exception:
                # Best-effort display only; narrow from a bare except so
                # SystemExit/KeyboardInterrupt are not swallowed.
                pass

        if st.button("πŸ”„ Reload AOE Model"):
            st.session_state.aoe_loaded = False
            # st.experimental_rerun was deprecated and later removed;
            # prefer st.rerun, falling back for old Streamlit versions.
            if hasattr(st, "rerun"):
                st.rerun()
            else:
                st.experimental_rerun()

        st.markdown("---")
        st.markdown("**πŸ“‹ Instructions:**")
        st.markdown("1. Configure API keys for GPT-4 and Gemini")
        st.markdown("2. Enter your prompt in the text area")
        st.markdown("3. Click 'Generate All Responses'")
        st.markdown("4. Compare responses side by side")

        st.markdown("---")
        st.markdown("**⚠️ Notes:**")
        st.markdown("- GPT-4 and Gemini require valid API keys")
        st.markdown("- AOE model runs locally from outputs/student/")
        st.markdown("- Responses are generated independently")
241
 
242
# Standard script entry point: build the Streamlit UI when run directly.
if __name__ == "__main__":
    main()