Rabe3 commited on
Commit
094e504
·
verified ·
1 Parent(s): 923692f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +242 -33
src/streamlit_app.py CHANGED
@@ -1,40 +1,249 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
 
 
 
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
+ import time
5
+ import os
6
 
7
+ # Page configuration
8
+ st.set_page_config(
9
+ page_title="Hakim AI Assistant",
10
+ page_icon="🤖",
11
+ layout="wide",
12
+ initial_sidebar_state="expanded"
13
+ )
14
 
15
+ # Custom CSS for better UI
16
+ st.markdown("""
17
+ <style>
18
+ .main-header {
19
+ text-align: center;
20
+ color: #2E86AB;
21
+ font-size: 2.5rem;
22
+ margin-bottom: 2rem;
23
+ }
24
+ .chat-message {
25
+ padding: 1rem;
26
+ border-radius: 10px;
27
+ margin: 1rem 0;
28
+ }
29
+ .user-message {
30
+ background-color: #E3F2FD;
31
+ border-left: 5px solid #2196F3;
32
+ }
33
+ .assistant-message {
34
+ background-color: #F1F8E9;
35
+ border-left: 5px solid #4CAF50;
36
+ }
37
+ .stTextArea textarea {
38
+ border-radius: 10px;
39
+ }
40
+ </style>
41
+ """, unsafe_allow_html=True)
42
 
43
+ @st.cache_resource
44
+ def load_model_and_tokenizer():
45
+ """Load the model and tokenizer with caching for better performance"""
46
+ try:
47
+ with st.spinner("Loading Hakim model... This may take a few minutes on first load."):
48
+ # Load tokenizer
49
+ tokenizer = AutoTokenizer.from_pretrained("Rabe3/Hakim")
50
+
51
+ # Load model with appropriate settings
52
+ model = AutoModelForCausalLM.from_pretrained(
53
+ "Rabe3/Hakim",
54
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
55
+ device_map="auto" if torch.cuda.is_available() else None,
56
+ trust_remote_code=True
57
+ )
58
+
59
+ # Create pipeline
60
+ text_pipeline = pipeline(
61
+ "text-generation",
62
+ model=model,
63
+ tokenizer=tokenizer,
64
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
65
+ device_map="auto" if torch.cuda.is_available() else None
66
+ )
67
+
68
+ return tokenizer, model, text_pipeline
69
+ except Exception as e:
70
+ st.error(f"Error loading model: {str(e)}")
71
+ return None, None, None
72
 
73
+ def generate_response(pipeline, prompt, system_prompt, max_length=512, temperature=0.7, top_p=0.9, do_sample=True):
74
+ """Generate response using the model pipeline"""
75
+ try:
76
+ # Combine system prompt with user input
77
+ full_prompt = f"{system_prompt}\n\nUser: {prompt}\nAssistant:"
78
+
79
+ # Generate response
80
+ with st.spinner("Generating response..."):
81
+ response = pipeline(
82
+ full_prompt,
83
+ max_length=max_length,
84
+ temperature=temperature,
85
+ top_p=top_p,
86
+ do_sample=do_sample,
87
+ pad_token_id=pipeline.tokenizer.eos_token_id,
88
+ return_full_text=False,
89
+ num_return_sequences=1
90
+ )
91
+
92
+ # Extract generated text
93
+ generated_text = response[0]['generated_text']
94
+
95
+ # Clean up the response (remove the prompt part if it's included)
96
+ if "Assistant:" in generated_text:
97
+ generated_text = generated_text.split("Assistant:")[-1].strip()
98
+
99
+ return generated_text
100
+
101
+ except Exception as e:
102
+ return f"Error generating response: {str(e)}"
103
 
104
+ def main():
105
+ # Header
106
+ st.markdown('<h1 class="main-header">🤖 Hakim AI Assistant</h1>', unsafe_allow_html=True)
107
+
108
+ # Load model
109
+ tokenizer, model, pipeline = load_model_and_tokenizer()
110
+
111
+ if pipeline is None:
112
+ st.error("Failed to load the model. Please refresh the page and try again.")
113
+ return
114
+
115
+ # Sidebar for configuration
116
+ with st.sidebar:
117
+ st.header("⚙️ Configuration")
118
+
119
+ # System prompt
120
+ system_prompt = st.text_area(
121
+ "System Prompt",
122
+ value="You are Hakim, a helpful AI assistant. You provide accurate, helpful, and informative responses. You communicate clearly and professionally.",
123
+ height=150,
124
+ help="This prompt sets the behavior and personality of the AI assistant."
125
+ )
126
+
127
+ st.divider()
128
+
129
+ # Generation parameters
130
+ st.subheader("Generation Parameters")
131
+
132
+ max_length = st.slider(
133
+ "Max Length",
134
+ min_value=50,
135
+ max_value=1000,
136
+ value=512,
137
+ step=50,
138
+ help="Maximum length of generated response"
139
+ )
140
+
141
+ temperature = st.slider(
142
+ "Temperature",
143
+ min_value=0.1,
144
+ max_value=2.0,
145
+ value=0.7,
146
+ step=0.1,
147
+ help="Controls randomness (lower = more focused, higher = more creative)"
148
+ )
149
+
150
+ top_p = st.slider(
151
+ "Top P",
152
+ min_value=0.1,
153
+ max_value=1.0,
154
+ value=0.9,
155
+ step=0.05,
156
+ help="Controls diversity via nucleus sampling"
157
+ )
158
+
159
+ do_sample = st.checkbox(
160
+ "Enable Sampling",
161
+ value=True,
162
+ help="Enable sampling for more diverse responses"
163
+ )
164
+
165
+ st.divider()
166
+
167
+ # Model info
168
+ st.subheader("ℹ️ Model Information")
169
+ st.info("**Model:** Rabe3/Hakim\n**Type:** Causal Language Model\n**Framework:** Transformers")
170
+
171
+ # Clear chat button
172
+ if st.button("🗑️ Clear Chat History", type="secondary"):
173
+ if 'messages' in st.session_state:
174
+ st.session_state.messages = []
175
+ st.rerun()
176
+
177
+ # Initialize chat history
178
+ if "messages" not in st.session_state:
179
+ st.session_state.messages = []
180
+
181
+ # Main chat interface
182
+ st.header("💬 Chat Interface")
183
+
184
+ # Display chat history
185
+ for message in st.session_state.messages:
186
+ if message["role"] == "user":
187
+ st.markdown(f'<div class="chat-message user-message"><strong>You:</strong> {message["content"]}</div>', unsafe_allow_html=True)
188
+ else:
189
+ st.markdown(f'<div class="chat-message assistant-message"><strong>Hakim:</strong> {message["content"]}</div>', unsafe_allow_html=True)
190
+
191
+ # Chat input
192
+ user_input = st.text_area(
193
+ "Enter your message:",
194
+ height=100,
195
+ placeholder="Type your message here...",
196
+ key="user_input"
197
+ )
198
+
199
+ col1, col2 = st.columns([1, 4])
200
+
201
+ with col1:
202
+ send_button = st.button("📤 Send", type="primary")
203
+
204
+ with col2:
205
+ if st.button("💡 Example Questions"):
206
+ examples = [
207
+ "What is artificial intelligence?",
208
+ "Can you help me write a short story?",
209
+ "Explain quantum computing in simple terms",
210
+ "What are the benefits of renewable energy?"
211
+ ]
212
+ st.session_state.user_input = st.selectbox("Choose an example:", [""] + examples)
213
+
214
+ # Process user input
215
+ if send_button and user_input.strip():
216
+ # Add user message to history
217
+ st.session_state.messages.append({"role": "user", "content": user_input})
218
+
219
+ # Generate response
220
+ response = generate_response(
221
+ pipeline=pipeline,
222
+ prompt=user_input,
223
+ system_prompt=system_prompt,
224
+ max_length=max_length,
225
+ temperature=temperature,
226
+ top_p=top_p,
227
+ do_sample=do_sample
228
+ )
229
+
230
+ # Add assistant response to history
231
+ st.session_state.messages.append({"role": "assistant", "content": response})
232
+
233
+ # Rerun to update the display
234
+ st.rerun()
235
+
236
+ # Footer
237
+ st.divider()
238
+ st.markdown(
239
+ """
240
+ <div style='text-align: center; color: #666; margin-top: 2rem;'>
241
+ <p>Powered by <strong>Rabe3/Hakim</strong> model from Hugging Face 🤗</p>
242
+ <p><em>This AI assistant is designed to be helpful, harmless, and honest.</em></p>
243
+ </div>
244
+ """,
245
+ unsafe_allow_html=True
246
+ )
247
 
248
+ if __name__ == "__main__":
249
+ main()