Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,18 +8,24 @@ from huggingface_hub import HfApi
|
|
| 8 |
import plotly.graph_objects as go
|
| 9 |
import time
|
| 10 |
from datetime import datetime
|
|
|
|
| 11 |
|
| 12 |
# Cyberpunk and Loading Animation Styling
|
| 13 |
def setup_cyberpunk_style():
|
| 14 |
st.markdown("""
|
| 15 |
<style>
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
.stApp {
|
| 20 |
background: radial-gradient(circle, rgba(0, 0, 0, 0.95) 20%, rgba(0, 50, 80, 0.95) 90%);
|
| 21 |
color: #00ff9d;
|
| 22 |
font-family: 'Orbitron', sans-serif;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
}
|
| 24 |
|
| 25 |
.main-title {
|
|
@@ -145,7 +151,9 @@ def initialize_model(model_name="gpt2"):
|
|
| 145 |
# Load Dataset Function with Uploaded File Option
|
| 146 |
def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
|
| 147 |
if data_source == "demo":
|
| 148 |
-
data = ["
|
|
|
|
|
|
|
| 149 |
elif uploaded_file is not None:
|
| 150 |
if uploaded_file.name.endswith(".txt"):
|
| 151 |
data = [uploaded_file.read().decode("utf-8")]
|
|
@@ -160,7 +168,7 @@ def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
|
|
| 160 |
return dataset
|
| 161 |
|
| 162 |
# Train Model Function with Customized Progress Bar
|
| 163 |
-
def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4):
|
| 164 |
training_args = TrainingArguments(
|
| 165 |
output_dir="./results",
|
| 166 |
overwrite_output_dir=True,
|
|
@@ -179,14 +187,26 @@ def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4):
|
|
| 179 |
args=training_args,
|
| 180 |
data_collator=data_collator,
|
| 181 |
train_dataset=train_dataset,
|
|
|
|
| 182 |
)
|
| 183 |
|
| 184 |
trainer.train()
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
# Main App Logic
|
| 187 |
def main():
|
| 188 |
setup_cyberpunk_style()
|
| 189 |
-
st.markdown('<h1 class="main-title">
|
| 190 |
|
| 191 |
# Initialize model and tokenizer
|
| 192 |
model, tokenizer = initialize_model()
|
|
@@ -225,6 +245,15 @@ def main():
|
|
| 225 |
# Load Dataset
|
| 226 |
train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)
|
| 227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
# Go Button to Start Training
|
| 229 |
if st.button("Go"):
|
| 230 |
progress_placeholder = st.empty()
|
|
@@ -233,22 +262,21 @@ def main():
|
|
| 233 |
|
| 234 |
dashboard = TrainingDashboard()
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
<div class="loading-animation"></div>
|
| 239 |
-
""", unsafe_allow_html=True)
|
| 240 |
-
|
| 241 |
-
train_model(model, train_dataset, tokenizer, epochs=1, batch_size=batch_size)
|
| 242 |
-
|
| 243 |
-
# Update Progress Bar
|
| 244 |
-
progress = (epoch + 1) / training_epochs * 100
|
| 245 |
progress_placeholder.markdown(f"""
|
| 246 |
<div class="progress-bar-container">
|
| 247 |
<div class="progress-bar" style="width: {progress}%;"></div>
|
| 248 |
</div>
|
| 249 |
""", unsafe_allow_html=True)
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
loading_animation.empty()
|
| 254 |
st.success("Training Complete!")
|
|
|
|
| 8 |
import plotly.graph_objects as go
|
| 9 |
import time
|
| 10 |
from datetime import datetime
|
| 11 |
+
import threading
|
| 12 |
|
| 13 |
# Cyberpunk and Loading Animation Styling
|
| 14 |
def setup_cyberpunk_style():
|
| 15 |
st.markdown("""
|
| 16 |
<style>
|
| 17 |
+
body, button, input, select, textarea {
|
| 18 |
+
font-family: 'Orbitron', sans-serif !important;
|
| 19 |
+
color: #00ff9d !important;
|
| 20 |
+
}
|
| 21 |
.stApp {
|
| 22 |
background: radial-gradient(circle, rgba(0, 0, 0, 0.95) 20%, rgba(0, 50, 80, 0.95) 90%);
|
| 23 |
color: #00ff9d;
|
| 24 |
font-family: 'Orbitron', sans-serif;
|
| 25 |
+
font-size: 16px;
|
| 26 |
+
line-height: 1.6;
|
| 27 |
+
padding: 20px;
|
| 28 |
+
box-sizing: border-box;
|
| 29 |
}
|
| 30 |
|
| 31 |
.main-title {
|
|
|
|
| 151 |
# Load Dataset Function with Uploaded File Option
|
| 152 |
def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
|
| 153 |
if data_source == "demo":
|
| 154 |
+
data = ["In the neon-lit streets of Neo-Tokyo, a lone hacker fights against the oppressive megacorporations.",
|
| 155 |
+
"The rain falls in sheets, washing away the bloodstains from the alleyways.",
|
| 156 |
+
"She plugs into the matrix, seeking answers to questions that have haunted her for years."]
|
| 157 |
elif uploaded_file is not None:
|
| 158 |
if uploaded_file.name.endswith(".txt"):
|
| 159 |
data = [uploaded_file.read().decode("utf-8")]
|
|
|
|
| 168 |
return dataset
|
| 169 |
|
| 170 |
# Train Model Function with Customized Progress Bar
|
| 171 |
+
def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4, progress_callback=None):
|
| 172 |
training_args = TrainingArguments(
|
| 173 |
output_dir="./results",
|
| 174 |
overwrite_output_dir=True,
|
|
|
|
| 187 |
args=training_args,
|
| 188 |
data_collator=data_collator,
|
| 189 |
train_dataset=train_dataset,
|
| 190 |
+
callbacks=[ProgressCallback(progress_callback)]
|
| 191 |
)
|
| 192 |
|
| 193 |
trainer.train()
|
| 194 |
|
| 195 |
+
class ProgressCallback(TrainerCallback):
    """Trainer callback that reports end-of-epoch metrics to a caller-supplied function.

    Args:
        progress_callback: Callable invoked as
            ``progress_callback(loss, generation, individual)`` at the end of
            every epoch, or ``None`` to disable reporting. ``train_model``
            defaults this argument to ``None`` and still installs the
            callback, so ``None`` must be tolerated here.
    """

    def __init__(self, progress_callback):
        super().__init__()
        self.progress_callback = progress_callback

    def on_epoch_end(self, args, state, control, **kwargs):
        """Forward the latest logged loss and step-derived counters to the callback.

        ``loss`` is the most recent ``'loss'`` value found in
        ``state.log_history``; it is ``None`` when no loss has been logged yet.
        """
        # No reporter configured (train_model's default): nothing to do.
        if self.progress_callback is None:
            return

        # The newest history entry is not guaranteed to exist or to carry a
        # 'loss' key, so search backwards for the latest entry that does
        # instead of indexing [-1]['loss'] directly.
        loss = next(
            (entry["loss"] for entry in reversed(state.log_history) if "loss" in entry),
            None,
        )
        generation = state.global_step // args.gradient_accumulation_steps + 1
        individual = args.gradient_accumulation_steps
        self.progress_callback(loss, generation, individual)
|
| 205 |
+
|
| 206 |
# Main App Logic
|
| 207 |
def main():
|
| 208 |
setup_cyberpunk_style()
|
| 209 |
+
st.markdown('<h1 class="main-title">Neural Training Hub</h1>', unsafe_allow_html=True)
|
| 210 |
|
| 211 |
# Initialize model and tokenizer
|
| 212 |
model, tokenizer = initialize_model()
|
|
|
|
| 245 |
# Load Dataset
|
| 246 |
train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)
|
| 247 |
|
| 248 |
+
# Chatbot Interaction
|
| 249 |
+
if st.checkbox("Enable Chatbot"):
|
| 250 |
+
user_input = st.text_input("You:", placeholder="Type your message here...")
|
| 251 |
+
if user_input:
|
| 252 |
+
inputs = tokenizer(user_input, return_tensors="pt")
|
| 253 |
+
outputs = model.generate(inputs['input_ids'], max_length=100, num_return_sequences=1)
|
| 254 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 255 |
+
st.write("Bot:", response)
|
| 256 |
+
|
| 257 |
# Go Button to Start Training
|
| 258 |
if st.button("Go"):
|
| 259 |
progress_placeholder = st.empty()
|
|
|
|
| 262 |
|
| 263 |
dashboard = TrainingDashboard()
|
| 264 |
|
| 265 |
+
def train_progress(loss, generation, individual):
|
| 266 |
+
progress = (generation + 1) / dashboard.metrics['training_epochs'] * 100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
progress_placeholder.markdown(f"""
|
| 268 |
<div class="progress-bar-container">
|
| 269 |
<div class="progress-bar" style="width: {progress}%;"></div>
|
| 270 |
</div>
|
| 271 |
""", unsafe_allow_html=True)
|
| 272 |
+
dashboard.update(loss=loss, generation=generation, individual=individual)
|
| 273 |
+
|
| 274 |
+
thread = threading.Thread(target=train_model, args=(model, train_dataset, tokenizer, training_epochs, batch_size, train_progress))
|
| 275 |
+
thread.start()
|
| 276 |
+
loading_animation.markdown("""
|
| 277 |
+
<div class="loading-animation"></div>
|
| 278 |
+
""", unsafe_allow_html=True)
|
| 279 |
+
thread.join()
|
| 280 |
|
| 281 |
loading_animation.empty()
|
| 282 |
st.success("Training Complete!")
|