ItCodinTime committed on
Commit
23d9eec
Β·
verified Β·
1 Parent(s): 95971ee

Add streamlit_app.py: AoE interactive model demo

Browse files
Files changed (1) hide show
  1. streamlit_app.py +160 -44
streamlit_app.py CHANGED
@@ -1,46 +1,162 @@
1
- <!doctype html>
2
- <html>
3
- <head>
4
- <title>Example Domain</title>
 
5
 
6
- <meta charset="utf-8" />
7
- <meta http-equiv="Content-type" content="text/html; charset=utf-8" />
8
- <meta name="viewport" content="width=device-width, initial-scale=1" />
9
- <style type="text/css">
10
- body {
11
- background-color: #f0f0f2;
12
- margin: 0;
13
- padding: 0;
14
- font-family: -apple-system, system-ui, BlinkMacSystemFont, "Segoe UI", "Open Sans", "Helvetica Neue", Helvetica, Arial, sans-serif;
15
-
16
- }
17
- div {
18
- width: 600px;
19
- margin: 5em auto;
20
- padding: 2em;
21
- background-color: #fdfdff;
22
- border-radius: 0.5em;
23
- box-shadow: 2px 3px 7px 2px rgba(0,0,0,0.02);
24
- }
25
- a:link, a:visited {
26
- color: #38488f;
27
- text-decoration: none;
28
- }
29
- @media (max-width: 700px) {
30
- div {
31
- margin: 0 auto;
32
- width: auto;
33
- }
34
- }
35
- </style>
36
- </head>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- <body>
39
- <div>
40
- <h1>Example Domain</h1>
41
- <p>This domain is for use in illustrative examples in documents. You may use this
42
- domain in literature without prior coordination or asking for permission.</p>
43
- <p><a href="https://www.iana.org/domains/example">More information...</a></p>
44
- </div>
45
- </body>
46
- </html>
 
1
+ import streamlit as st
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import traceback
6
 
7
def load_model():
    """Load the AoE model and tokenizer from the outputs/student/ directory.

    Returns:
        tuple: (model, tokenizer) on success, or (None, None) on failure.
        Failures are reported through the Streamlit UI rather than raised.
    """
    model_path = "outputs/student/"

    try:
        # isdir (not exists) — a stray file named "outputs/student/" should not pass.
        if not os.path.isdir(model_path):
            st.error(f"Model directory '{model_path}' not found. Please ensure the model files are present.")
            return None, None

        # Weights may be stored as legacy pytorch_model.bin OR as
        # model.safetensors (the transformers default since v4.27);
        # warn only when neither format is present.
        has_weights = any(
            os.path.exists(os.path.join(model_path, name))
            for name in ("pytorch_model.bin", "model.safetensors")
        )
        missing_files = [
            name for name in ("config.json", "tokenizer.json")
            if not os.path.exists(os.path.join(model_path, name))
        ]
        if not has_weights:
            missing_files.append("pytorch_model.bin / model.safetensors")

        if missing_files:
            st.warning(f"Some model files may be missing: {missing_files}. Attempting to load anyway...")

        # Load tokenizer and model; trust_remote_code allows custom AoE model classes.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # fp16 only on GPU — CPU fp16 inference is slow/unsupported for many ops.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True,
        )
        model.eval()  # inference only; disables dropout etc.

        return model, tokenizer

    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.text(f"Traceback: {traceback.format_exc()}")
        return None, None
38
+
39
def generate_response(model, tokenizer, prompt, max_length=512):
    """Generate a completion for `prompt` with the loaded model.

    Args:
        model: Causal LM returned by load_model().
        tokenizer: Matching tokenizer.
        prompt: User text to complete.
        max_length: Maximum number of NEW tokens to generate (default 512).

    Returns:
        str: The generated text with the echoed prompt removed, or an
        error message string if generation failed.
    """
    try:
        # Tokenize input
        inputs = tokenizer.encode(prompt, return_tensors="pt")

        # Move the input ids to whichever device the model's weights live on.
        # This also covers device_map="auto" placements, where a plain
        # .cuda() call could target the wrong device.
        inputs = inputs.to(next(model.parameters()).device)

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                # max_new_tokens bounds only the generated part; the old
                # max_length=len(input)+N form is deprecated and silently
                # truncates when the prompt approaches the context limit.
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                # Prefer the tokenizer's real pad token; fall back to EOS
                # for tokenizers that define no pad token.
                pad_token_id=(
                    tokenizer.pad_token_id
                    if tokenizer.pad_token_id is not None
                    else tokenizer.eos_token_id
                ),
            )

        # Decode response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # generate() echoes the prompt; strip it so only the completion remains.
        if response.startswith(prompt):
            response = response[len(prompt):].strip()

        return response

    except Exception as e:
        return f"Error generating response: {str(e)}"
71
+
72
def main():
    """Render the Streamlit chat UI: load the model once, then serve prompts."""
    st.title("🏰 AoE Model Chat Demo")
    st.markdown("Interactive chat interface for the AoE (Attention over Experts) model.")

    # Initialize session state for model caching so the model survives reruns.
    if 'model' not in st.session_state:
        st.session_state.model = None
        st.session_state.tokenizer = None
        st.session_state.model_loaded = False

    # Load model on first run or after an explicit reload request.
    if not st.session_state.model_loaded:
        with st.spinner("Loading AoE model from outputs/student/..."):
            model, tokenizer = load_model()
            if model is not None and tokenizer is not None:
                st.session_state.model = model
                st.session_state.tokenizer = tokenizer
                st.session_state.model_loaded = True
                st.success("βœ… Model loaded successfully!")
            else:
                st.error("❌ Failed to load model. Please check the error messages above.")
                return

    # Chat interface
    st.markdown("---")
    st.subheader("πŸ’¬ Chat with the Model")

    # User input
    user_prompt = st.text_area(
        "Enter your prompt:",
        placeholder="Type your message here...",
        height=100,
        help="Enter a prompt to chat with the AoE model"
    )

    # Generation parameters
    col1, col2 = st.columns(2)
    with col1:
        max_length = st.slider("Max response length", 50, 1000, 512, 50)
    with col2:
        if st.button("πŸ”„ Reload Model", help="Reload the model if needed"):
            st.session_state.model_loaded = False
            # st.experimental_rerun() was removed in Streamlit 1.27+;
            # st.rerun() is the supported replacement.
            st.rerun()

    # Submit button
    if st.button("πŸš€ Generate Response", type="primary"):
        if not user_prompt.strip():
            st.warning("Please enter a prompt first.")
        elif st.session_state.model_loaded:
            with st.spinner("Generating response..."):
                response = generate_response(
                    st.session_state.model,
                    st.session_state.tokenizer,
                    user_prompt,
                    max_length
                )

            # Display response
            st.markdown("---")
            st.subheader("πŸ€– Model Response:")
            st.write(response)
        else:
            st.error("Model not loaded. Please check the error messages above.")

    # Model info sidebar
    with st.sidebar:
        st.header("ℹ️ Model Info")
        st.write("**Model Path:** outputs/student/")
        st.write(f"**Model Loaded:** {'βœ… Yes' if st.session_state.model_loaded else '❌ No'}")

        if st.session_state.model_loaded:
            # Best-effort device display; narrow except so real bugs
            # (and KeyboardInterrupt) are not silently swallowed.
            try:
                model_info = f"**Device:** {next(st.session_state.model.parameters()).device}"
                st.write(model_info)
            except Exception:
                pass

        st.markdown("---")
        st.markdown("**Instructions:**")
        st.markdown("1. Enter your prompt in the text area")
        st.markdown("2. Adjust max response length if needed")
        st.markdown("3. Click 'Generate Response' to chat")

        st.markdown("---")
        st.markdown("**Troubleshooting:**")
        st.markdown("- Ensure model files exist in outputs/student/")
        st.markdown("- Required files: config.json, pytorch_model.bin, tokenizer files")
        st.markdown("- Use 'Reload Model' if issues occur")
160
 
161
# Standard script entry point: run the Streamlit app when executed directly,
# but allow the module to be imported without side effects.
if __name__ == "__main__":
    main()