Nihal2000 committed on
Commit 8d369f6 · verified · 1 Parent(s): dd61416

Update app.py

Files changed (1)
  1. app.py +167 -162
app.py CHANGED
@@ -1,192 +1,197 @@
  import streamlit as st
  import os
- import time
- from model_manager import ModelManager
- from inference_engine import InferenceEngine
  import torch

- # Page configuration
  st.set_page_config(
-     page_title="Automotive SLM Chatbot",
      page_icon="🚗",
-     layout="wide",
-     initial_sidebar_state="expanded"
  )

- # Custom CSS
- st.markdown("""
- <style>
-     .main-header {
-         font-size: 2.5rem;
-         color: #1f77b4;
-         text-align: center;
-         margin-bottom: 2rem;
-     }
-     .chat-message {
-         padding: 1rem;
-         border-radius: 0.5rem;
-         margin: 0.5rem 0;
-     }
-     .user-message {
-         background-color: #e3f2fd;
-         border-left: 4px solid #1976d2;
-     }
-     .assistant-message {
-         background-color: #f3e5f5;
-         border-left: 4px solid #7b1fa2;
-     }
-     .model-info {
-         background-color: #f5f5f5;
-         padding: 1rem;
-         border-radius: 0.5rem;
-         border: 1px solid #ddd;
-     }
- </style>
- """, unsafe_allow_html=True)

- @st.cache_resource
- def load_model_manager():
-     """Cache the model manager to avoid reloading"""
-     return ModelManager("assets")

- def initialize_session_state():
-     """Initialize session state variables"""
-     if "messages" not in st.session_state:
-         st.session_state.messages = []
-     if "current_model" not in st.session_state:
-         st.session_state.current_model = None
-     if "inference_engine" not in st.session_state:
-         st.session_state.inference_engine = None

- def display_chat_message(role, content, model_info=None):
-     """Display a chat message with proper styling"""
-     if role == "user":
-         st.markdown(f"""
-         <div class="chat-message user-message">
-             <strong>You:</strong> {content}
-         </div>
-         """, unsafe_allow_html=True)
-     else:
-         model_text = f" ({model_info})" if model_info else ""
-         st.markdown(f"""
-         <div class="chat-message assistant-message">
-             <strong>Assistant{model_text}:</strong> {content}
-         </div>
-         """, unsafe_allow_html=True)

  def main():
-     # Initialize session state
-     initialize_session_state()

-     # Header
-     st.markdown('<h1 class="main-header">🚗 Automotive SLM Chatbot</h1>', unsafe_allow_html=True)

-     # Load model manager
-     model_manager = load_model_manager()

-     # Sidebar for model selection and settings
      with st.sidebar:
-         st.header("⚙️ Model Settings")
-
-         # Model selection
-         available_models = model_manager.get_available_models()
-         if available_models:
-             selected_model = st.selectbox(
-                 "Select Model:",
-                 available_models,
-                 index=0 if st.session_state.current_model is None else available_models.index(st.session_state.current_model) if st.session_state.current_model in available_models else 0
-             )
-
-             # Load model if changed
-             if selected_model != st.session_state.current_model:
-                 with st.spinner(f"Loading {selected_model}..."):
-                     model, tokenizer, config = model_manager.load_model(selected_model)
-                     st.session_state.inference_engine = InferenceEngine(model, tokenizer, config)
-                     st.session_state.current_model = selected_model
-                     st.success(f"Model {selected_model} loaded successfully!")
-         else:
-             st.error("No models found in assets folder!")
-             st.stop()
-
-         # Model information
-         if st.session_state.inference_engine:
-             st.subheader("📊 Model Info")
-             model_info = model_manager.get_model_info(selected_model)
-             st.markdown(f"""
-             <div class="model-info">
-                 <strong>Model:</strong> {model_info['name']}<br>
-                 <strong>Type:</strong> {model_info['type']}<br>
-                 <strong>Parameters:</strong> {model_info['parameters']}<br>
-                 <strong>Size:</strong> {model_info['size']}
-             </div>
-             """, unsafe_allow_html=True)

-         # Generation settings
-         st.subheader("🎛️ Generation Settings")
-         max_tokens = st.slider("Max Tokens", 10, 200, 50)
-         temperature = st.slider("Temperature", 0.1, 2.0, 0.8, 0.1)
-         top_p = st.slider("Top P", 0.1, 1.0, 0.9, 0.05)
-         top_k = st.slider("Top K", 1, 100, 50)
-
-         # Clear chat button
-         if st.button("🗑️ Clear Chat"):
-             st.session_state.messages = []
-             st.rerun()

-     # Main chat interface
-     if st.session_state.inference_engine is None:
-         st.info("Please select a model from the sidebar to start chatting.")
-         return

      # Display chat history
-     chat_container = st.container()
-     with chat_container:
-         for message in st.session_state.messages:
-             display_chat_message(
-                 message["role"],
-                 message["content"],
-                 message.get("model", None)
-             )

      # Chat input
-     prompt = st.chat_input("Ask me about automotive topics...")
-
-     if prompt:
          # Add user message
          st.session_state.messages.append({"role": "user", "content": prompt})

-         # Display user message
-         with chat_container:
-             display_chat_message("user", prompt)

-         # Generate response
-         with st.spinner("Generating response..."):
-             try:
-                 response = st.session_state.inference_engine.generate_response(
-                     prompt,
-                     max_tokens=max_tokens,
-                     temperature=temperature,
-                     top_p=top_p,
-                     top_k=top_k
-                 )
-
-                 # Add assistant message
-                 st.session_state.messages.append({
-                     "role": "assistant",
-                     "content": response,
-                     "model": selected_model
-                 })
-
-                 # Display assistant message
-                 with chat_container:
-                     display_chat_message("assistant", response, selected_model)
-
-             except Exception as e:
-                 st.error(f"Error generating response: {str(e)}")
-
-     # Footer
-     st.markdown("---")
-     st.markdown("*Powered by Automotive SLM - Specialized for automotive assistance*")

  if __name__ == "__main__":
      main()
 
  import streamlit as st
  import os
  import torch
+ import warnings
+ import logging
+ from transformers import AutoTokenizer
+ import gc

+ # HF Spaces specific configuration
  st.set_page_config(
+     page_title="🚗 Automotive SLM Assistant",
      page_icon="🚗",
+     layout="wide"
  )

+ # Suppress warnings for HF Spaces
+ warnings.filterwarnings('ignore')
+ logging.getLogger('streamlit').setLevel(logging.ERROR)
+ logging.getLogger('transformers').setLevel(logging.ERROR)

+ # HF Spaces optimized model loading
+ @st.cache_resource(show_spinner="🚀 Loading your Automotive AI Assistant...")
+ def load_model_for_hf_spaces():
+     """Optimized model loading for HF Spaces environment"""
+     try:
+         # Force CPU usage for HF Spaces
+         device = torch.device('cpu')
+
+         # Load tokenizer first
+         tokenizer = AutoTokenizer.from_pretrained("gpt2")
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         # Simplified model loading for HF Spaces
+         # You would upload your model files to the HF Spaces repo
+         if os.path.exists("model.pt"):
+             checkpoint = torch.load("model.pt", map_location=device)
+
+             # Create simple config if not in checkpoint
+             config = {
+                 'd_model': 256,
+                 'n_layer': 4,
+                 'n_head': 4,
+                 'vocab_size': 50257,
+                 'n_positions': 256,
+                 'use_moe': True,
+                 'n_experts': 4
+             }
+
+             # Use simplified model class for HF Spaces
+             model = SimpleAutomotiveModel(config)
+
+             if 'model_state_dict' in checkpoint:
+                 model.load_state_dict(checkpoint['model_state_dict'])
+
+             model.eval()
+
+             return model, tokenizer, config
+         else:
+             st.error("Model file not found. Please upload your model.pt to the repository.")
+             return None, None, None
+
+     except Exception as e:
+         st.error(f"Error loading model: {e}")
+         return None, None, None

+ # Simplified model class for HF Spaces
+ class SimpleAutomotiveModel(torch.nn.Module):
+     """Simplified model for HF Spaces deployment"""
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embeddings = torch.nn.Embedding(config['vocab_size'], config['d_model'])
+         self.layers = torch.nn.ModuleList([
+             torch.nn.TransformerEncoderLayer(
+                 d_model=config['d_model'],
+                 nhead=config['n_head'],
+                 batch_first=True
+             ) for _ in range(config['n_layer'])
+         ])
+         self.ln_f = torch.nn.LayerNorm(config['d_model'])
+         self.lm_head = torch.nn.Linear(config['d_model'], config['vocab_size'], bias=False)
+
+     def forward(self, input_ids):
+         x = self.embeddings(input_ids)
+         for layer in self.layers:
+             x = layer(x)
+         x = self.ln_f(x)
+         return {"logits": self.lm_head(x)}
+
+     def generate(self, input_ids, max_new_tokens=50, temperature=0.8, **kwargs):
+         """Simple generation for HF Spaces"""
+         device = input_ids.device
+         generated = input_ids.clone()
+
+         for _ in range(max_new_tokens):
+             with torch.no_grad():
+                 outputs = self.forward(generated)
+                 logits = outputs["logits"][:, -1, :] / temperature
+                 probs = torch.softmax(logits, dim=-1)
+                 next_token = torch.multinomial(probs, 1)
+                 generated = torch.cat([generated, next_token], dim=1)
+
+                 # Simple stopping condition
+                 if next_token.item() == 50256:  # EOS token
+                     break
+
+         return generated

+ def generate_response(model, tokenizer, prompt, max_tokens=50, temperature=0.8):
+     """Generate response optimized for HF Spaces"""
+     try:
+         # Tokenize
+         inputs = tokenizer(prompt, return_tensors="pt", max_length=200, truncation=True)
+
+         # Generate
+         with torch.no_grad():
+             outputs = model.generate(
+                 inputs['input_ids'],
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 pad_token_id=tokenizer.pad_token_id
+             )
+
+         # Decode
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Remove original prompt
+         if response.startswith(prompt):
+             response = response[len(prompt):].strip()
+
+         return response if response else "I apologize, but I couldn't generate a proper response. Please try rephrasing your question."
+
+     except Exception as e:
+         return f"I encountered an error: {str(e)}. Please try again."

  def main():
+     # Title and description
+     st.title("🚗 Automotive SLM Assistant")
+     st.markdown("*Specialized AI assistant for automotive questions and troubleshooting*")

+     # Load model
+     model, tokenizer, config = load_model_for_hf_spaces()

+     if model is None:
+         st.stop()

+     # Sidebar settings
      with st.sidebar:
+         st.header("⚙️ Settings")
+         max_tokens = st.slider("Response Length", 20, 100, 50)
+         temperature = st.slider("Creativity", 0.3, 1.5, 0.8, 0.1)

+         st.markdown("---")
+         st.markdown("### 🎯 Example Questions")
+         st.markdown("""
+         - How do I check tire pressure?
+         - What does the check engine light mean?
+         - How to jump start a car?
+         - Electric vehicle charging tips
+         - Brake maintenance schedule
+         """)

+     # Initialize chat history
+     if "messages" not in st.session_state:
+         st.session_state.messages = [
+             {"role": "assistant", "content": "Hello! I'm your Automotive AI Assistant. Ask me anything about cars, maintenance, troubleshooting, or automotive technology!"}
+         ]

      # Display chat history
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])

      # Chat input
+     if prompt := st.chat_input("Ask me about automotive topics..."):
          # Add user message
          st.session_state.messages.append({"role": "user", "content": prompt})
+         with st.chat_message("user"):
+             st.markdown(prompt)

+         # Generate and display response
+         with st.chat_message("assistant"):
+             with st.spinner("🤔 Thinking..."):
+                 response = generate_response(model, tokenizer, prompt, max_tokens, temperature)
+                 st.markdown(response)
+                 st.session_state.messages.append({"role": "assistant", "content": response})

+     # Cleanup for HF Spaces memory management
+     if len(st.session_state.messages) > 20:  # Keep last 20 messages
+         st.session_state.messages = st.session_state.messages[-20:]
+
+     # Force garbage collection
+     gc.collect()

  if __name__ == "__main__":
      main()
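
A quick way to sanity-check this commit locally is the sketch below. It is not part of the commit, and the file name smoke_test.py is made up: it assumes the new app.py above sits in the working directory, skips load_model_for_hf_spaces() entirely (so no model.pt is needed), and drives SimpleAutomotiveModel with its random initial weights. The sampled text is gibberish by design; the test only confirms that tokenization, the forward pass, and the sampling loop run end to end on CPU.

# smoke_test.py — hypothetical helper, NOT part of this commit.
# Assumes the new app.py above is in the working directory and that
# streamlit, torch, and transformers are installed. No model.pt is needed:
# the model keeps its random initial weights, so the generated text is
# meaningless; the test only proves the generation path executes.
import torch
from transformers import AutoTokenizer

# Importing app executes its top-level st.* calls; in Streamlit's "bare
# mode" (plain `python smoke_test.py`) these should be harmless no-ops.
from app import SimpleAutomotiveModel, generate_response

# Same config dict that load_model_for_hf_spaces() hard-codes above.
config = {
    'd_model': 256, 'n_layer': 4, 'n_head': 4,
    'vocab_size': 50257, 'n_positions': 256,
    'use_moe': True, 'n_experts': 4,
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = SimpleAutomotiveModel(config)  # random weights
model.eval()

# Forward-pass shape check: logits should be (batch, seq_len, vocab_size).
ids = tokenizer("How do I check tire pressure?", return_tensors="pt")["input_ids"]
logits = model(ids)["logits"]
assert logits.shape == (1, ids.shape[1], config['vocab_size'])

# End-to-end check through the commit's generate_response() helper.
print(generate_response(model, tokenizer, "How do I check tire pressure?", max_tokens=10))

One detail worth noting: generate_response() passes pad_token_id to model.generate(), which SimpleAutomotiveModel.generate() does not use; the **kwargs in its signature absorbs the extra argument rather than raising a TypeError, so the call above works as written.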