ItCodinTime committed
Commit 3aa34bf · verified · 1 Parent(s): 23d9eec

Implement LLM comparison interface for GPT-4, Gemini, and AOE models

Files changed (1)
  1. streamlit_app.py +149 -66
streamlit_app.py CHANGED
@@ -3,8 +3,18 @@ import torch
 import os
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import traceback
+import requests
+import json
+from typing import Dict, Optional
 
-def load_model():
+# Configure the page
+st.set_page_config(
+    page_title="LLM Comparison: GPT-4 vs Gemini vs AOE",
+    page_icon="⚔️",
+    layout="wide"
+)
+
+def load_aoe_model():
     """Load the AoE model and tokenizer from outputs/student/ directory"""
     model_path = "outputs/student/"
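
The body of load_aoe_model() is elided by the hunk above. Given the imports this file already has (torch, AutoTokenizer, AutoModelForCausalLM), a minimal sketch of such a loader could look like the following; the st.cache_resource decorator and the device handling are assumptions, not part of this commit:

    @st.cache_resource
    def load_aoe_model():
        """Sketch: load the AoE student model and tokenizer from a local path."""
        model_path = "outputs/student/"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
        # Use a GPU when one is available, otherwise stay on CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()
        return model, tokenizer

st.cache_resource would cache the loaded weights across Streamlit reruns, which is an alternative to the manual st.session_state caching used in main() further down.
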
 
@@ -32,12 +42,12 @@ def load_model():
         return model, tokenizer
 
     except Exception as e:
-        st.error(f"Error loading model: {str(e)}")
+        st.error(f"Error loading AoE model: {str(e)}")
         st.text(f"Traceback: {traceback.format_exc()}")
         return None, None
 
-def generate_response(model, tokenizer, prompt, max_length=512):
-    """Generate response from the model"""
+def generate_aoe_response(model, tokenizer, prompt, max_length=512):
+    """Generate response from the AoE model"""
     try:
         # Tokenize input
         inputs = tokenizer.encode(prompt, return_tensors="pt")
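
The generation code between the tokenization shown above and the `return response` at the start of the next hunk is not visible in this diff. With these imports, a typical body for generate_aoe_response() could look like the sketch below; the sampling settings are illustrative assumptions, not taken from this commit:

    # Move inputs to the model's device and generate without tracking gradients.
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            inputs,
            max_length=max_length,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode the full generated sequence back to text.
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response
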
@@ -67,96 +77,169 @@ def generate_response(model, tokenizer, prompt, max_length=512):
         return response
 
     except Exception as e:
-        return f"Error generating response: {str(e)}"
+        return f"Error generating AoE response: {str(e)}"
+
+def query_gpt4_api(prompt: str, api_key: Optional[str] = None) -> str:
+    """Query GPT-4 API (placeholder - requires API key)"""
+    if not api_key:
+        return "❌ GPT-4 API key not configured. Please add your OpenAI API key to use GPT-4."
+
+    try:
+        # This is a placeholder implementation - would need actual OpenAI API integration
+        return "🤖 GPT-4 response would appear here with proper API configuration."
+    except Exception as e:
+        return f"Error querying GPT-4: {str(e)}"
+
+def query_gemini_api(prompt: str, api_key: Optional[str] = None) -> str:
+    """Query Gemini API (placeholder - requires API key)"""
+    if not api_key:
+        return "❌ Gemini API key not configured. Please add your Google API key to use Gemini."
+
+    try:
+        # This is a placeholder implementation - would need actual Google Gemini API integration
+        return "🤖 Gemini response would appear here with proper API configuration."
+    except Exception as e:
+        return f"Error querying Gemini: {str(e)}"
 
 def main():
-    st.title("🏰 AoE Model Chat Demo")
-    st.markdown("Interactive chat interface for the AoE (Attention over Experts) model.")
+    st.title("⚔️ LLM Comparison: GPT-4 vs Gemini vs AOE")
+    st.markdown("Compare responses from three different language models side by side.")
 
     # Initialize session state for model caching
-    if 'model' not in st.session_state:
-        st.session_state.model = None
-        st.session_state.tokenizer = None
-        st.session_state.model_loaded = False
+    if 'aoe_model' not in st.session_state:
+        st.session_state.aoe_model = None
+        st.session_state.aoe_tokenizer = None
+        st.session_state.aoe_loaded = False
 
-    # Load model on first run or if not loaded
-    if not st.session_state.model_loaded:
-        with st.spinner("Loading AoE model from outputs/student/..."):
-            model, tokenizer = load_model()
+    # Load AOE model on first run
+    if not st.session_state.aoe_loaded:
+        with st.spinner("Loading AOE model from outputs/student/..."):
+            model, tokenizer = load_aoe_model()
             if model is not None and tokenizer is not None:
-                st.session_state.model = model
-                st.session_state.tokenizer = tokenizer
-                st.session_state.model_loaded = True
-                st.success("✅ Model loaded successfully!")
+                st.session_state.aoe_model = model
+                st.session_state.aoe_tokenizer = tokenizer
+                st.session_state.aoe_loaded = True
+                st.success("✅ AOE model loaded successfully!")
             else:
-                st.error("❌ Failed to load model. Please check the error messages above.")
-                return
+                st.error("❌ Failed to load AOE model. Check error messages above.")
+
+    # Configuration section
+    st.markdown("---")
+    st.subheader("🔧 Configuration")
+
+    col1, col2, col3 = st.columns(3)
+
+    with col1:
+        openai_api_key = st.text_input(
+            "OpenAI API Key (for GPT-4)",
+            type="password",
+            help="Enter your OpenAI API key to enable GPT-4 responses"
+        )
+
+    with col2:
+        google_api_key = st.text_input(
+            "Google API Key (for Gemini)",
+            type="password",
+            help="Enter your Google API key to enable Gemini responses"
+        )
+
+    with col3:
+        max_length = st.slider(
+            "Max Response Length",
+            min_value=100,
+            max_value=1000,
+            value=512,
+            step=50,
+            help="Maximum length for generated responses"
+        )
 
-    # Chat interface
+    # Main comparison interface
     st.markdown("---")
-    st.subheader("💬 Chat with the Model")
+    st.subheader("💬 Compare LLM Responses")
 
     # User input
     user_prompt = st.text_area(
         "Enter your prompt:",
-        placeholder="Type your message here...",
-        height=100,
-        help="Enter a prompt to chat with the AoE model"
+        placeholder="Type your prompt here to compare responses from all three models...",
+        height=120,
+        help="Enter a prompt to see how different LLMs respond"
    )
 
-    # Generation parameters
-    col1, col2 = st.columns(2)
-    with col1:
-        max_length = st.slider("Max response length", 50, 1000, 512, 50)
-    with col2:
-        if st.button("🔄 Reload Model", help="Reload the model if needed"):
-            st.session_state.model_loaded = False
-            st.experimental_rerun()
-
-    # Submit button
-    if st.button("🚀 Generate Response", type="primary"):
+    # Generate responses button
+    if st.button("🚀 Generate All Responses", type="primary"):
         if not user_prompt.strip():
             st.warning("Please enter a prompt first.")
-        elif st.session_state.model_loaded:
-            with st.spinner("Generating response..."):
-                response = generate_response(
-                    st.session_state.model,
-                    st.session_state.tokenizer,
-                    user_prompt,
-                    max_length
-                )
-
-                # Display response
-                st.markdown("---")
-                st.subheader("🤖 Model Response:")
-                st.write(response)
         else:
-            st.error("Model not loaded. Please check the error messages above.")
+            # Create three columns for side-by-side comparison
+            col1, col2, col3 = st.columns(3)
+
+            with col1:
+                st.markdown("### 🤖 GPT-4")
+                with st.spinner("Generating GPT-4 response..."):
+                    gpt4_response = query_gpt4_api(user_prompt, openai_api_key)
+                st.markdown("**Response:**")
+                st.write(gpt4_response)
+
+            with col2:
+                st.markdown("### 🌟 Gemini")
+                with st.spinner("Generating Gemini response..."):
+                    gemini_response = query_gemini_api(user_prompt, google_api_key)
+                st.markdown("**Response:**")
+                st.write(gemini_response)
+
+            with col3:
+                st.markdown("### 🏰 AOE (Local)")
+                if st.session_state.aoe_loaded:
+                    with st.spinner("Generating AOE response..."):
+                        aoe_response = generate_aoe_response(
+                            st.session_state.aoe_model,
+                            st.session_state.aoe_tokenizer,
+                            user_prompt,
+                            max_length
+                        )
+                    st.markdown("**Response:**")
+                    st.write(aoe_response)
+                else:
+                    st.error("AOE model not loaded. Please reload the page.")
 
-    # Model info sidebar
+    # Model information sidebar
     with st.sidebar:
-        st.header("ℹ️ Model Info")
-        st.write("**Model Path:** outputs/student/")
-        st.write(f"**Model Loaded:** {'✅ Yes' if st.session_state.model_loaded else '❌ No'}")
+        st.header("ℹ️ Model Information")
 
-        if st.session_state.model_loaded:
+        st.markdown("**🤖 GPT-4**")
+        st.write(f"Status: {'✅ Configured' if openai_api_key else '❌ API key needed'}")
+        st.write("Provider: OpenAI")
+
+        st.markdown("**🌟 Gemini**")
+        st.write(f"Status: {'✅ Configured' if google_api_key else '❌ API key needed'}")
+        st.write("Provider: Google")
+
+        st.markdown("**🏰 AOE (Local)**")
+        st.write(f"Status: {'✅ Loaded' if st.session_state.aoe_loaded else '❌ Not loaded'}")
+        st.write("Path: outputs/student/")
+        if st.session_state.aoe_loaded:
             try:
-                model_info = f"**Device:** {next(st.session_state.model.parameters()).device}"
-                st.write(model_info)
+                device_info = f"Device: {next(st.session_state.aoe_model.parameters()).device}"
+                st.write(device_info)
             except:
                 pass
 
+        if st.button("🔄 Reload AOE Model"):
+            st.session_state.aoe_loaded = False
+            st.experimental_rerun()
+
         st.markdown("---")
-        st.markdown("**Instructions:**")
-        st.markdown("1. Enter your prompt in the text area")
-        st.markdown("2. Adjust max response length if needed")
-        st.markdown("3. Click 'Generate Response' to chat")
+        st.markdown("**📋 Instructions:**")
+        st.markdown("1. Configure API keys for GPT-4 and Gemini")
+        st.markdown("2. Enter your prompt in the text area")
+        st.markdown("3. Click 'Generate All Responses'")
+        st.markdown("4. Compare responses side by side")
 
         st.markdown("---")
-        st.markdown("**Troubleshooting:**")
-        st.markdown("- Ensure model files exist in outputs/student/")
-        st.markdown("- Required files: config.json, pytorch_model.bin, tokenizer files")
-        st.markdown("- Use 'Reload Model' if issues occur")
+        st.markdown("**⚠️ Notes:**")
+        st.markdown("- GPT-4 and Gemini require valid API keys")
+        st.markdown("- AOE model runs locally from outputs/student/")
+        st.markdown("- Responses are generated independently")
 
 if __name__ == "__main__":
     main()
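
query_gpt4_api() is left as a placeholder in this commit. A minimal sketch of a real implementation against the OpenAI chat completions endpoint, using the requests import added above, could look like the following; the model name, timeout, and error handling are illustrative assumptions, not part of the commit:

    def query_gpt4_api(prompt: str, api_key: Optional[str] = None) -> str:
        """Sketch: send the prompt to OpenAI's chat completions API."""
        if not api_key:
            return "❌ GPT-4 API key not configured."
        try:
            resp = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": "gpt-4",
                    "messages": [{"role": "user", "content": prompt}],
                },
                timeout=60,
            )
            resp.raise_for_status()
            # The assistant reply is the first choice's message content.
            return resp.json()["choices"][0]["message"]["content"]
        except Exception as e:
            return f"Error querying GPT-4: {str(e)}"

In a deployed Space, the key could also come from st.secrets instead of the sidebar text input.
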
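query_gemini_api() is likewise a placeholder. A comparable sketch against the Google Generative Language REST API follows; the model name and response parsing are assumptions and may need adjusting to the current Gemini API version:

    def query_gemini_api(prompt: str, api_key: Optional[str] = None) -> str:
        """Sketch: send the prompt to the Gemini generateContent endpoint."""
        if not api_key:
            return "❌ Gemini API key not configured."
        try:
            url = (
                "https://generativelanguage.googleapis.com/v1beta/models/"
                f"gemini-pro:generateContent?key={api_key}"
            )
            resp = requests.post(
                url,
                json={"contents": [{"parts": [{"text": prompt}]}]},
                timeout=60,
            )
            resp.raise_for_status()
            # The generated text sits in the first candidate's first part.
            return resp.json()["candidates"][0]["content"]["parts"][0]["text"]
        except Exception as e:
            return f"Error querying Gemini: {str(e)}"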