msmaje committed
Commit e532a61 · verified · 1 parent: 715c9e3

Update app.py

Files changed (1):
  1. app.py +12 -12
app.py CHANGED
```diff
@@ -73,13 +73,13 @@ status_manager = StatusManager()
 # --- Model Loading ---
 def initialize_model_background():
     """
-    Loads the base pre-trained language model (DialoGPT-medium) and its tokenizer
+    Loads the base pre-trained language model (distilgpt2) and its tokenizer
     in a background thread to keep the Gradio UI responsive.
     """
     global model, tokenizer
 
     try:
-        status_manager.update_status("🔄 Loading base DialoGPT-medium model...", 10)
+        status_manager.update_status("🔄 Loading base distilgpt2 model...", 10)
 
         # Clear CUDA cache if a GPU is available to free up memory before loading a new model
         if torch.cuda.is_available():
```
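The context line closing this hunk shows the pre-load cleanup step: the app frees cached GPU memory before pulling in a new model. A minimal sketch of that pattern, assuming only what the context lines show (the standalone helper name and the `synchronize()` call are illustrative additions, not part of app.py):

```python
import torch

def clear_gpu_cache() -> None:
    """Hypothetical helper mirroring the cleanup in initialize_model_background()."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # hand cached, unused blocks back to the allocator
        torch.cuda.synchronize()   # let in-flight kernels finish before reloading
```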
```diff
@@ -87,8 +87,8 @@ def initialize_model_background():
 
         status_manager.update_status("🔄 Downloading model weights (this might take a while)...", 30)
 
-        # Model name for Microsoft DialoGPT-medium, a good general-purpose conversational model
-        model_name = "microsoft/DialoGPT-medium"
+        # Changed model to distilgpt2 for lighter computation
+        model_name = "distilgpt2"
 
         # Load the tokenizer associated with the model
         tokenizer = AutoTokenizer.from_pretrained(
```
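For reference, here is the model swap in context, as a minimal sketch. Only `AutoTokenizer.from_pretrained` is visible in the hunk; the `AutoModelForCausalLM` call and the pad-token fallback are assumptions (GPT-2-family tokenizers, distilgpt2's included, ship without a pad token, so reusing EOS is a common workaround):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "distilgpt2"  # previously "microsoft/DialoGPT-medium"

tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumed fallback, not shown in the diff

model = AutoModelForCausalLM.from_pretrained(model_name)
```

distilgpt2 (6 transformer layers, roughly 82M parameters) is a fraction of DialoGPT-medium's ~345M, which is the "lighter computation" the new comment refers to.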
```diff
@@ -164,14 +164,15 @@ def prepare_model_for_training():
         status_manager.update_status("✅ Model already prepared for training", 100)
         return "✅ Model already prepared for training"
 
-    # Define LoRA configuration. Target modules are specific to DialoGPT's architecture.
+    # Define LoRA configuration. Target modules are specific to distilgpt2's architecture.
     lora_config = LoraConfig(
         task_type=TaskType.CAUSAL_LM,
         r=8,  # LoRA attention dimension (e.g., 8, 16, 32)
         lora_alpha=16,  # Alpha parameter for LoRA scaling
         lora_dropout=0.1,  # Dropout probability for LoRA layers
         bias="none",  # Bias type (none, all, lora_only)
-        target_modules=["c_attn", "c_proj"],  # Key attention and projection layers in DialoGPT
+        # Adjusted target modules for distilgpt2
+        target_modules=["c_attn", "c_proj", "c_fc"],
     )
 
     # Apply LoRA to the base model, making only a small portion trainable
```
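As a standalone sketch, the updated adapter setup looks like this (assuming `model` is the distilgpt2 instance loaded above). distilgpt2 keeps GPT-2's Conv1D module names, and peft matches `target_modules` entries as name suffixes, so `"c_attn"`/`"c_proj"` cover the attention block while the newly added `"c_fc"` (plus the MLP's own `c_proj`) extends LoRA to the feed-forward path:

```python
from peft import LoraConfig, TaskType, get_peft_model

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                # LoRA rank
    lora_alpha=16,      # scaling factor (effective scale = alpha / r)
    lora_dropout=0.1,
    bias="none",
    target_modules=["c_attn", "c_proj", "c_fc"],  # GPT-2-style Conv1D modules
)

peft_model = get_peft_model(model, lora_config)  # `model` assumed from the loading step
peft_model.print_trainable_parameters()  # reports trainable vs. total parameter counts
```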
```diff
@@ -362,11 +363,11 @@ def train_model_background(batch_size, grad_accum, epochs, lr):
     """
     global model, tokenizer, trainer, training_stats, train_dataset, eval_dataset
 
-    # Enable PyTorch anomaly detection for debugging in-place operation errors
-    # WARNING: This can significantly slow down training, use only for debugging.
-    # It will provide a detailed traceback to pinpoint the exact problematic operation.
-    torch.autograd.set_detect_anomaly(True)
-    print("PyTorch anomaly detection is ENABLED. Training may be slower but will provide detailed error traces.")
+    # Disable PyTorch anomaly detection for faster training.
+    # Re-enable if in-place modification errors persist with the new model.
+    # torch.autograd.set_detect_anomaly(True)
+    # print("PyTorch anomaly detection is ENABLED. Training may be slower but will provide detailed error traces.")
+    print("PyTorch anomaly detection is DISABLED for faster training.")
 
     try:
         # Step 1: Ensure dataset is loaded and ready
```
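The removed lines flipped PyTorch's global debugging switch on for every run; the commit keeps them only as comments. When an in-place error does need hunting, the scoped context manager confines the slowdown to one block instead of the whole training job. A sketch with toy tensors (illustrative only, not from app.py):

```python
import torch

# Scoped form: anomaly detection is active only inside this block, so normal
# training runs keep full speed.
with torch.autograd.detect_anomaly():
    x = torch.randn(4, requires_grad=True)
    loss = (x * 2.0).sum()
    loss.backward()  # on failure, autograd names the forward op behind the bad gradient
```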
```diff
@@ -689,4 +690,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
```
 