Update app.py
app.py CHANGED
@@ -73,13 +73,13 @@ status_manager = StatusManager()
 # --- Model Loading ---
 def initialize_model_background():
     """
-    Loads the base pre-trained language model (
+    Loads the base pre-trained language model (distilgpt2) and its tokenizer
     in a background thread to keep the Gradio UI responsive.
     """
     global model, tokenizer

     try:
-        status_manager.update_status("π Loading base
+        status_manager.update_status("π Loading base distilgpt2 model...", 10)

         # Clear CUDA cache if a GPU is available to free up memory before loading a new model
         if torch.cuda.is_available():
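The docstring above says the load runs in a background thread so the Gradio UI stays responsive; a minimal sketch of that pattern with the standard library (the wiring below is an assumption, the commit itself does not show it):

import threading

# Hypothetical wiring: run the slow model load without blocking the Gradio UI.
loader_thread = threading.Thread(target=initialize_model_background, daemon=True)
loader_thread.start()
# The UI can keep polling status_manager for progress while the thread works.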
@@ -87,8 +87,8 @@ def initialize_model_background():

         status_manager.update_status("π Downloading model weights (this might take a while)...", 30)

-        #
-        model_name = "
+        # Changed model to distilgpt2 for lighter computation
+        model_name = "distilgpt2"

         # Load the tokenizer associated with the model
         tokenizer = AutoTokenizer.from_pretrained(
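The hunk cuts off at the start of the tokenizer call; a minimal sketch of what a distilgpt2 load typically looks like with the transformers API (the pad-token fallback and device handling are assumptions, not shown in this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "distilgpt2"

# GPT-2-style tokenizers ship without a pad token, so fall back to EOS (assumption).
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the causal-LM weights onto GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)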
@@ -164,14 +164,15 @@ def prepare_model_for_training():
         status_manager.update_status("β Model already prepared for training", 100)
         return "β Model already prepared for training"

-    # Define LoRA configuration. Target modules are specific to
+    # Define LoRA configuration. Target modules are specific to distilgpt2's architecture.
     lora_config = LoraConfig(
         task_type=TaskType.CAUSAL_LM,
         r=8,  # LoRA attention dimension (e.g., 8, 16, 32)
         lora_alpha=16,  # Alpha parameter for LoRA scaling
         lora_dropout=0.1,  # Dropout probability for LoRA layers
         bias="none",  # Bias type (none, all, lora_only)
-
+        # Adjusted target modules for distilgpt2
+        target_modules=["c_attn", "c_proj", "c_fc"],
     )

     # Apply LoRA to the base model, making only a small portion trainable
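The comment on the last context line points at the apply step, which sits outside this hunk; a minimal sketch of how such a config is usually applied with peft (get_peft_model and the variable names are assumptions):

from peft import LoraConfig, TaskType, get_peft_model

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["c_attn", "c_proj", "c_fc"],  # Conv1D layers inside each distilgpt2 block
)

# Wrap the base model; only the injected low-rank adapter weights stay trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # reports trainable vs. total parameter counts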
@@ -362,11 +363,11 @@ def train_model_background(batch_size, grad_accum, epochs, lr):
     """
     global model, tokenizer, trainer, training_stats, train_dataset, eval_dataset

-    #
-    #
-    #
-
-    print("PyTorch anomaly detection is
+    # Disable PyTorch anomaly detection for faster training.
+    # Re-enable if in-place modification errors persist with the new model.
+    # torch.autograd.set_detect_anomaly(True)
+    # print("PyTorch anomaly detection is ENABLED. Training may be slower but will provide detailed error traces.")
+    print("PyTorch anomaly detection is DISABLED for faster training.")

     try:
         # Step 1: Ensure dataset is loaded and ready
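The commit keeps the debug lines as comments; a minimal sketch of the same toggle behind a flag, so it can be flipped back on while chasing in-place autograd errors (the DEBUG_AUTOGRAD flag is hypothetical, not part of the Space):

import torch

DEBUG_AUTOGRAD = False  # hypothetical flag; flip to True only while debugging

if DEBUG_AUTOGRAD:
    # Records every backward op so in-place modification errors point at the offending line.
    torch.autograd.set_detect_anomaly(True)
    print("PyTorch anomaly detection is ENABLED. Training may be slower but will provide detailed error traces.")
else:
    print("PyTorch anomaly detection is DISABLED for faster training.")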
@@ -689,4 +690,3 @@ def main():

 if __name__ == "__main__":
     main()
-