aagamjtdev committed on
Commit
1636abd
·
1 Parent(s): 0c2088f

correction

Browse files
Files changed (1) hide show
  1. app.py +375 -52
app.py CHANGED
@@ -713,6 +713,290 @@
713
  # demo.launch()
714
 
715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
716
  import gradio as gr
717
  import subprocess
718
  import os
@@ -735,6 +1019,7 @@ MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
735
  def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
736
  """
737
  Handles the Gradio submission and executes the training script using subprocess.
 
738
  """
739
 
740
  # 1. Setup: Create output directory if it doesn't exist
@@ -742,13 +1027,22 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
742
 
743
  # 2. File Handling: Use the temporary path of the uploaded file
744
  if dataset_file is None:
745
- return "❌ ERROR: Please upload a file.", None, gr.Button(visible=False)
 
746
 
747
- # Using .name (Corrected in previous steps)
748
- input_path = dataset_file.name
 
 
 
 
 
 
 
749
 
750
  if not input_path.lower().endswith(".json"):
751
- return "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None, gr.Button(visible=False)
 
752
 
753
  progress(0.1, desc="Starting LayoutLMv3 Training...")
754
 
@@ -767,6 +1061,7 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
767
  ]
768
 
769
  log_output += f"Executing command: {' '.join(command)}\n\n"
 
770
 
771
  try:
772
  # 4. Run the training script and capture output
@@ -783,40 +1078,73 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
783
  log_output += line
784
  # Print to console as well for debugging
785
  print(line, end='')
 
 
786
 
787
  process.stdout.close()
788
  return_code = process.wait()
789
 
790
  # 5. Check for successful completion
791
  if return_code == 0:
792
- log_output += "\nβœ… TRAINING COMPLETE! Model saved."
 
 
793
  print("\nβœ… TRAINING COMPLETE! Model saved.")
794
 
795
  # 6. Verify model file exists
796
  if os.path.exists(MODEL_FILE_PATH):
797
  file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
798
- log_output += f"\nπŸ“¦ Model file: {MODEL_FILE_PATH}"
799
  log_output += f"\nπŸ“Š Model size: {file_size:.2f} MB"
800
 
801
  print(f"\nβœ… Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")
802
 
803
- # Create a copy in the root directory with timestamp for uniqueness
 
 
 
804
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
805
- download_filename = f"layoutlmv3_trained_{timestamp}.pth"
 
 
806
 
807
  try:
808
- shutil.copy2(MODEL_FILE_PATH, download_filename)
809
- log_output += f"\nπŸ“‹ Download file created: {download_filename}"
810
- print(f"βœ… Created download file: {download_filename}")
 
 
 
 
 
 
 
 
 
 
 
811
  except Exception as e:
812
- log_output += f"\n⚠️ Could not create download file: {e}"
813
- download_filename = MODEL_FILE_PATH
 
 
814
 
815
- # Return the path and make download button visible
816
- log_output += f"\n\nπŸŽ‰ SUCCESS! Click the 'Download Model' button below to save your model."
817
- log_output += f"\n⚠️ IMPORTANT: Download NOW - file will be deleted when Space restarts!"
 
 
 
 
 
 
 
 
818
 
819
- return log_output, download_filename, gr.Button(visible=True)
 
 
 
820
  else:
821
  log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
822
  log_output += f"\nπŸ” Checking directory contents..."
@@ -828,42 +1156,28 @@ def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
828
  else:
829
  log_output += f"\n❌ Directory {MODEL_OUTPUT_DIR} does not exist!"
830
 
831
- return log_output, None, gr.Button(visible=False)
 
832
  else:
833
- log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
834
- return log_output, None, gr.Button(visible=False)
 
 
 
 
835
 
836
  except FileNotFoundError:
837
  error_msg = f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
838
  print(error_msg)
839
- return error_msg, None, gr.Button(visible=False)
 
840
  except Exception as e:
841
  error_msg = f"❌ An unexpected error occurred: {e}"
842
  print(error_msg)
843
  import traceback
844
  print(traceback.format_exc())
845
- return error_msg, None, gr.Button(visible=False)
846
-
847
-
848
- def download_model():
849
- """
850
- Returns the model file for download.
851
- """
852
- if os.path.exists(MODEL_FILE_PATH):
853
- return MODEL_FILE_PATH
854
- else:
855
- # Check for any .pth files in current directory
856
- pth_files = [f for f in os.listdir('.') if f.endswith('.pth')]
857
- if pth_files:
858
- return pth_files[0]
859
-
860
- # Check checkpoints directory
861
- if os.path.exists(MODEL_OUTPUT_DIR):
862
- pth_files = [os.path.join(MODEL_OUTPUT_DIR, f) for f in os.listdir(MODEL_OUTPUT_DIR) if f.endswith('.pth')]
863
- if pth_files:
864
- return pth_files[0]
865
-
866
- return None
867
 
868
 
869
  # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
@@ -877,6 +1191,7 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
877
  - **Download your model IMMEDIATELY** after training completes!
878
  - The model file is **temporary** and will be deleted when the Space restarts.
879
  - A download button will appear below once training is complete.
 
880
 
881
  **⏱️ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
882
  """
@@ -916,15 +1231,15 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
916
  train_button = gr.Button("πŸ”₯ Start Training", variant="primary", size="lg")
917
 
918
  with gr.Column(scale=2):
919
- gr.Markdown("### πŸ“Š Training Progress")
920
 
921
  log_output = gr.Textbox(
922
- label="Training Logs",
923
  lines=25,
924
  max_lines=30,
925
  autoscroll=True,
926
  show_copy_button=True,
927
- placeholder="Click 'Start Training' to begin...\n\nLogs will appear here in real-time."
928
  )
929
 
930
  gr.Markdown("### ⬇️ Download Trained Model")
@@ -942,7 +1257,7 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
942
 
943
  # File output for download
944
  model_download = gr.File(
945
- label="Your trained model will appear here",
946
  interactive=False,
947
  visible=True
948
  )
@@ -950,19 +1265,21 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
950
  gr.Markdown(
951
  """
952
  **πŸ“₯ Download Instructions:**
953
- 1. Wait for training to complete (βœ… appears in logs)
954
- 2. Click the **"Download Model"** button above
955
- 3. Save the `.pth` file to your local machine
956
- 4. **Do this immediately** - file is temporary!
 
957
 
958
  **πŸ”§ Troubleshooting:**
959
  - If download button doesn't appear, check the logs for errors
960
  - Try reducing epochs or batch size if timeout occurs
961
  - Ensure your JSON file is properly formatted
 
962
  """
963
  )
964
 
965
- # Define the training action
966
  train_button.click(
967
  fn=train_model,
968
  inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
@@ -988,6 +1305,12 @@ with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as de
988
  - Passages
989
 
990
  **Model Details:** LayoutLMv3-base + CRF layer for sequence labeling
 
 
 
 
 
 
991
  """
992
  )
993
 
 
713
  # demo.launch()
714
 
715
 
716
+ # import gradio as gr
717
+ # import subprocess
718
+ # import os
719
+ # import sys
720
+ # from datetime import datetime
721
+ # import shutil
722
+ #
723
+ # # FIX: Update the script name to the correct one you uploaded
724
+ # TRAINING_SCRIPT = "HF_LayoutLM_with_Passage.py"
725
+ #
726
+ # # --- CORRECTED MODEL PATH BASED ON YOUR SCRIPT ---
727
+ # MODEL_OUTPUT_DIR = "checkpoints"
728
+ # MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
729
+ # MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
730
+ #
731
+ #
732
+ # # ----------------------------------------------------------------
733
+ #
734
+ #
735
+ # def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
736
+ # """
737
+ # Handles the Gradio submission and executes the training script using subprocess.
738
+ # """
739
+ #
740
+ # # 1. Setup: Create output directory if it doesn't exist
741
+ # os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
742
+ #
743
+ # # 2. File Handling: Use the temporary path of the uploaded file
744
+ # if dataset_file is None:
745
+ # return "❌ ERROR: Please upload a file.", None, gr.Button(visible=False)
746
+ #
747
+ # # Using .name (Corrected in previous steps)
748
+ # input_path = dataset_file.name
749
+ #
750
+ # if not input_path.lower().endswith(".json"):
751
+ # return "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None, gr.Button(visible=False)
752
+ #
753
+ # progress(0.1, desc="Starting LayoutLMv3 Training...")
754
+ #
755
+ # log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
756
+ #
757
+ # # 3. Construct the subprocess command
758
+ # command = [
759
+ # sys.executable,
760
+ # TRAINING_SCRIPT,
761
+ # "--mode", "train",
762
+ # "--input", input_path,
763
+ # "--batch_size", str(batch_size),
764
+ # "--epochs", str(epochs),
765
+ # "--lr", str(lr),
766
+ # "--max_len", str(max_len)
767
+ # ]
768
+ #
769
+ # log_output += f"Executing command: {' '.join(command)}\n\n"
770
+ #
771
+ # try:
772
+ # # 4. Run the training script and capture output
773
+ # process = subprocess.Popen(
774
+ # command,
775
+ # stdout=subprocess.PIPE,
776
+ # stderr=subprocess.STDOUT,
777
+ # text=True,
778
+ # bufsize=1
779
+ # )
780
+ #
781
+ # # Stream logs in real-time
782
+ # for line in iter(process.stdout.readline, ""):
783
+ # log_output += line
784
+ # # Print to console as well for debugging
785
+ # print(line, end='')
786
+ #
787
+ # process.stdout.close()
788
+ # return_code = process.wait()
789
+ #
790
+ # # 5. Check for successful completion
791
+ # if return_code == 0:
792
+ # log_output += "\nβœ… TRAINING COMPLETE! Model saved."
793
+ # print("\nβœ… TRAINING COMPLETE! Model saved.")
794
+ #
795
+ # # 6. Verify model file exists
796
+ # if os.path.exists(MODEL_FILE_PATH):
797
+ # file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
798
+ # log_output += f"\nπŸ“¦ Model file: {MODEL_FILE_PATH}"
799
+ # log_output += f"\nπŸ“Š Model size: {file_size:.2f} MB"
800
+ #
801
+ # print(f"\nβœ… Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")
802
+ #
803
+ # # Create a copy in the root directory with timestamp for uniqueness
804
+ # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
805
+ # download_filename = f"layoutlmv3_trained_{timestamp}.pth"
806
+ #
807
+ # try:
808
+ # shutil.copy2(MODEL_FILE_PATH, download_filename)
809
+ # log_output += f"\nπŸ“‹ Download file created: {download_filename}"
810
+ # print(f"βœ… Created download file: {download_filename}")
811
+ # except Exception as e:
812
+ # log_output += f"\n⚠️ Could not create download file: {e}"
813
+ # download_filename = MODEL_FILE_PATH
814
+ #
815
+ # # Return the path and make download button visible
816
+ # log_output += f"\n\nπŸŽ‰ SUCCESS! Click the 'Download Model' button below to save your model."
817
+ # log_output += f"\n⚠️ IMPORTANT: Download NOW - file will be deleted when Space restarts!"
818
+ #
819
+ # return log_output, download_filename, gr.Button(visible=True)
820
+ # else:
821
+ # log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
822
+ # log_output += f"\nπŸ” Checking directory contents..."
823
+ #
824
+ # # List files in checkpoints directory for debugging
825
+ # if os.path.exists(MODEL_OUTPUT_DIR):
826
+ # files = os.listdir(MODEL_OUTPUT_DIR)
827
+ # log_output += f"\nπŸ“ Files in {MODEL_OUTPUT_DIR}: {files}"
828
+ # else:
829
+ # log_output += f"\n❌ Directory {MODEL_OUTPUT_DIR} does not exist!"
830
+ #
831
+ # return log_output, None, gr.Button(visible=False)
832
+ # else:
833
+ # log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
834
+ # return log_output, None, gr.Button(visible=False)
835
+ #
836
+ # except FileNotFoundError:
837
+ # error_msg = f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
838
+ # print(error_msg)
839
+ # return error_msg, None, gr.Button(visible=False)
840
+ # except Exception as e:
841
+ # error_msg = f"❌ An unexpected error occurred: {e}"
842
+ # print(error_msg)
843
+ # import traceback
844
+ # print(traceback.format_exc())
845
+ # return error_msg, None, gr.Button(visible=False)
846
+ #
847
+ #
848
+ # def download_model():
849
+ # """
850
+ # Returns the model file for download.
851
+ # """
852
+ # if os.path.exists(MODEL_FILE_PATH):
853
+ # return MODEL_FILE_PATH
854
+ # else:
855
+ # # Check for any .pth files in current directory
856
+ # pth_files = [f for f in os.listdir('.') if f.endswith('.pth')]
857
+ # if pth_files:
858
+ # return pth_files[0]
859
+ #
860
+ # # Check checkpoints directory
861
+ # if os.path.exists(MODEL_OUTPUT_DIR):
862
+ # pth_files = [os.path.join(MODEL_OUTPUT_DIR, f) for f in os.listdir(MODEL_OUTPUT_DIR) if f.endswith('.pth')]
863
+ # if pth_files:
864
+ # return pth_files[0]
865
+ #
866
+ # return None
867
+ #
868
+ #
869
+ # # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
870
+ # with gr.Blocks(title="LayoutLMv3 Fine-Tuning App", theme=gr.themes.Soft()) as demo:
871
+ # gr.Markdown("# πŸš€ LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
872
+ # gr.Markdown(
873
+ # """
874
+ # Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model.
875
+ #
876
+ # **⚠️ IMPORTANT - Free Tier Users:**
877
+ # - **Download your model IMMEDIATELY** after training completes!
878
+ # - The model file is **temporary** and will be deleted when the Space restarts.
879
+ # - A download button will appear below once training is complete.
880
+ #
881
+ # **⏱️ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
882
+ # """
883
+ # )
884
+ #
885
+ # with gr.Row():
886
+ # with gr.Column(scale=1):
887
+ # gr.Markdown("### πŸ“ Dataset Upload")
888
+ # file_input = gr.File(
889
+ # label="Upload Label Studio JSON Dataset",
890
+ # file_types=[".json"]
891
+ # )
892
+ #
893
+ # gr.Markdown("---")
894
+ # gr.Markdown("### βš™οΈ Training Parameters")
895
+ #
896
+ # batch_size_input = gr.Slider(
897
+ # minimum=1, maximum=16, step=1, value=4,
898
+ # label="Batch Size",
899
+ # info="Smaller = less memory, slower training"
900
+ # )
901
+ # epochs_input = gr.Slider(
902
+ # minimum=1, maximum=10, step=1, value=3,
903
+ # label="Epochs",
904
+ # info="Fewer epochs = faster training (recommended: 3-5)"
905
+ # )
906
+ # lr_input = gr.Number(
907
+ # value=5e-5, label="Learning Rate",
908
+ # info="Default: 5e-5"
909
+ # )
910
+ # max_len_input = gr.Slider(
911
+ # minimum=128, maximum=512, step=128, value=512,
912
+ # label="Max Sequence Length",
913
+ # info="Shorter = faster training, less memory"
914
+ # )
915
+ #
916
+ # train_button = gr.Button("πŸ”₯ Start Training", variant="primary", size="lg")
917
+ #
918
+ # with gr.Column(scale=2):
919
+ # gr.Markdown("### πŸ“Š Training Progress")
920
+ #
921
+ # log_output = gr.Textbox(
922
+ # label="Training Logs",
923
+ # lines=25,
924
+ # max_lines=30,
925
+ # autoscroll=True,
926
+ # show_copy_button=True,
927
+ # placeholder="Click 'Start Training' to begin...\n\nLogs will appear here in real-time."
928
+ # )
929
+ #
930
+ # gr.Markdown("### ⬇️ Download Trained Model")
931
+ #
932
+ # # Hidden state to store the file path
933
+ # model_path_state = gr.State(value=None)
934
+ #
935
+ # # Download button (initially hidden)
936
+ # download_btn = gr.Button(
937
+ # "πŸ“₯ Download Model (.pth file)",
938
+ # variant="primary",
939
+ # size="lg",
940
+ # visible=False
941
+ # )
942
+ #
943
+ # # File output for download
944
+ # model_download = gr.File(
945
+ # label="Your trained model will appear here",
946
+ # interactive=False,
947
+ # visible=True
948
+ # )
949
+ #
950
+ # gr.Markdown(
951
+ # """
952
+ # **πŸ“₯ Download Instructions:**
953
+ # 1. Wait for training to complete (βœ… appears in logs)
954
+ # 2. Click the **"Download Model"** button above
955
+ # 3. Save the `.pth` file to your local machine
956
+ # 4. **Do this immediately** - file is temporary!
957
+ #
958
+ # **πŸ”§ Troubleshooting:**
959
+ # - If download button doesn't appear, check the logs for errors
960
+ # - Try reducing epochs or batch size if timeout occurs
961
+ # - Ensure your JSON file is properly formatted
962
+ # """
963
+ # )
964
+ #
965
+ # # Define the training action
966
+ # train_button.click(
967
+ # fn=train_model,
968
+ # inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
969
+ # outputs=[log_output, model_path_state, download_btn],
970
+ # api_name="train"
971
+ # )
972
+ #
973
+ # # Define the download action
974
+ # download_btn.click(
975
+ # fn=lambda path: path,
976
+ # inputs=[model_path_state],
977
+ # outputs=[model_download]
978
+ # )
979
+ #
980
+ # # Add example info
981
+ # gr.Markdown(
982
+ # """
983
+ # ---
984
+ # ### πŸ“– About
985
+ # This Space fine-tunes LayoutLMv3 with CRF for document understanding tasks including:
986
+ # - Questions, Options, Answers
987
+ # - Section Headings
988
+ # - Passages
989
+ #
990
+ # **Model Details:** LayoutLMv3-base + CRF layer for sequence labeling
991
+ # """
992
+ # )
993
+ #
994
+ # if __name__ == "__main__":
995
+ # demo.launch()
996
+
997
+
998
+
999
+
1000
  import gradio as gr
1001
  import subprocess
1002
  import os
 
1019
  def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
1020
  """
1021
  Handles the Gradio submission and executes the training script using subprocess.
1022
+ Yields logs in real-time for user feedback.
1023
  """
1024
 
1025
  # 1. Setup: Create output directory if it doesn't exist
 
1027
 
1028
  # 2. File Handling: Use the temporary path of the uploaded file
1029
  if dataset_file is None:
1030
+ yield "❌ ERROR: Please upload a file.", None, gr.Button(visible=False)
1031
+ return
1032
 
1033
+ # CRITICAL FIX: dataset_file is a gradio.File object, use .name to get the path
1034
+ # This is a temporary file path like /tmp/gradio/.../filename.json
1035
+ input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)
1036
+
1037
+ # Verify the file actually exists before proceeding
1038
+ if not os.path.exists(input_path):
1039
+ error_msg = f"❌ ERROR: Uploaded file not found at {input_path}. Please try uploading again."
1040
+ yield error_msg, None, gr.Button(visible=False)
1041
+ return
1042
 
1043
  if not input_path.lower().endswith(".json"):
1044
+ yield "❌ ERROR: Please upload a valid Label Studio JSON file (.json).", None, gr.Button(visible=False)
1045
+ return
1046
 
1047
  progress(0.1, desc="Starting LayoutLMv3 Training...")
1048
 
 
1061
  ]
1062
 
1063
  log_output += f"Executing command: {' '.join(command)}\n\n"
1064
+ yield log_output, None, gr.Button(visible=False) # Initial yield
1065
 
1066
  try:
1067
  # 4. Run the training script and capture output
 
1078
  log_output += line
1079
  # Print to console as well for debugging
1080
  print(line, end='')
1081
+ # Yield updated logs in real-time
1082
+ yield log_output, None, gr.Button(visible=False)
1083
 
1084
  process.stdout.close()
1085
  return_code = process.wait()
1086
 
1087
  # 5. Check for successful completion
1088
  if return_code == 0:
1089
+ log_output += "\n" + "=" * 60 + "\n"
1090
+ log_output += "βœ… TRAINING COMPLETE! Model saved successfully.\n"
1091
+ log_output += "=" * 60 + "\n"
1092
  print("\nβœ… TRAINING COMPLETE! Model saved.")
1093
 
1094
  # 6. Verify model file exists
1095
  if os.path.exists(MODEL_FILE_PATH):
1096
  file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024) # Size in MB
1097
+ log_output += f"\nπŸ“¦ Model file found: {MODEL_FILE_PATH}"
1098
  log_output += f"\nπŸ“Š Model size: {file_size:.2f} MB"
1099
 
1100
  print(f"\nβœ… Model exists at: {MODEL_FILE_PATH} ({file_size:.2f} MB)")
1101
 
1102
+ # CRITICAL: Copy to a simple location that Gradio can reliably serve
1103
+ # Use the same temp directory pattern as the uploaded JSON file
1104
+ import tempfile
1105
+ temp_dir = tempfile.gettempdir()
1106
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
1107
+
1108
+ # Create filename in temp directory
1109
+ temp_model_path = os.path.join(temp_dir, f"layoutlmv3_trained_{timestamp}.pth")
1110
 
1111
  try:
1112
+ # Copy the model to temp directory
1113
+ shutil.copy2(MODEL_FILE_PATH, temp_model_path)
1114
+ log_output += f"\nπŸ“‹ Model copied to temporary download location"
1115
+ log_output += f"\nπŸ”— Download path: {temp_model_path}"
1116
+ print(f"βœ… Model copied to temp location: {temp_model_path}")
1117
+
1118
+ # Verify the copy exists
1119
+ if os.path.exists(temp_model_path):
1120
+ log_output += f"\nβœ… Download file verified and ready!"
1121
+ download_path = temp_model_path
1122
+ else:
1123
+ log_output += f"\n⚠️ Warning: Temp copy verification failed, using original path"
1124
+ download_path = MODEL_FILE_PATH
1125
+
1126
  except Exception as e:
1127
+ log_output += f"\n⚠️ Could not create temp copy: {e}"
1128
+ log_output += f"\nπŸ“ Using original path: {MODEL_FILE_PATH}"
1129
+ print(f"⚠️ Copy failed: {e}, using original path")
1130
+ download_path = MODEL_FILE_PATH
1131
 
1132
+ # Final success message
1133
+ log_output += f"\n\n{'=' * 60}"
1134
+ log_output += f"\nπŸŽ‰ SUCCESS! Your model is ready for download."
1135
+ log_output += f"\n{'=' * 60}"
1136
+ log_output += f"\n\n⬇️ Click the 'πŸ“₯ Download Model' button below to save your model."
1137
+ log_output += f"\n⚠️ CRITICAL: Download NOW! File will be deleted when:"
1138
+ log_output += f"\n - This tab is closed"
1139
+ log_output += f"\n - Space restarts or goes idle"
1140
+ log_output += f"\n - System clears temp files"
1141
+ log_output += f"\n\nπŸ“₯ The file will download as a .pth file to your computer's Downloads folder."
1142
+ log_output += f"\n\n{'=' * 60}\n"
1143
 
1144
+ # Return final logs and make download button visible
1145
+ # IMPORTANT: Return the path that Gradio can access
1146
+ yield log_output, download_path, gr.Button(visible=True)
1147
+ return
1148
  else:
1149
  log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
1150
  log_output += f"\nπŸ” Checking directory contents..."
 
1156
  else:
1157
  log_output += f"\n❌ Directory {MODEL_OUTPUT_DIR} does not exist!"
1158
 
1159
+ yield log_output, None, gr.Button(visible=False)
1160
+ return
1161
  else:
1162
+ log_output += f"\n\n{'=' * 60}\n"
1163
+ log_output += f"❌ TRAINING FAILED with return code {return_code}\n"
1164
+ log_output += f"{'=' * 60}\n"
1165
+ log_output += f"\nPlease check the logs above for error details.\n"
1166
+ yield log_output, None, gr.Button(visible=False)
1167
+ return
1168
 
1169
  except FileNotFoundError:
1170
  error_msg = f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space."
1171
  print(error_msg)
1172
+ yield log_output + "\n" + error_msg, None, gr.Button(visible=False)
1173
+ return
1174
  except Exception as e:
1175
  error_msg = f"❌ An unexpected error occurred: {e}"
1176
  print(error_msg)
1177
  import traceback
1178
  print(traceback.format_exc())
1179
+ yield log_output + "\n" + error_msg, None, gr.Button(visible=False)
1180
+ return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1181
 
1182
 
1183
  # --- Gradio Interface Setup (using Blocks for a nicer layout) ---
 
1191
  - **Download your model IMMEDIATELY** after training completes!
1192
  - The model file is **temporary** and will be deleted when the Space restarts.
1193
  - A download button will appear below once training is complete.
1194
+ - **Real-time logs** will stream during training so you can monitor progress.
1195
 
1196
  **⏱️ Timeout Note:** Training may timeout on free tier. Consider reducing epochs or batch size for faster training.
1197
  """
 
1231
  train_button = gr.Button("πŸ”₯ Start Training", variant="primary", size="lg")
1232
 
1233
  with gr.Column(scale=2):
1234
+ gr.Markdown("### πŸ“Š Training Progress (Real-Time Logs)")
1235
 
1236
  log_output = gr.Textbox(
1237
+ label="Training Logs - Updates in Real-Time",
1238
  lines=25,
1239
  max_lines=30,
1240
  autoscroll=True,
1241
  show_copy_button=True,
1242
+ placeholder="Click 'Start Training' to begin...\n\nLogs will stream here in real-time as training progresses."
1243
  )
1244
 
1245
  gr.Markdown("### ⬇️ Download Trained Model")
 
1257
 
1258
  # File output for download
1259
  model_download = gr.File(
1260
+ label="Your trained model will appear here after clicking Download",
1261
  interactive=False,
1262
  visible=True
1263
  )
 
1265
  gr.Markdown(
1266
  """
1267
  **πŸ“₯ Download Instructions:**
1268
+ 1. Wait for training to complete - watch the real-time logs above
1269
+ 2. Look for **"βœ… TRAINING COMPLETE!"** message
1270
+ 3. Click the **"πŸ“₯ Download Model"** button that appears above
1271
+ 4. Save the `.pth` file to your local machine
1272
+ 5. **Do this immediately** - file is temporary and will be deleted on Space restart!
1273
 
1274
  **πŸ”§ Troubleshooting:**
1275
  - If download button doesn't appear, check the logs for errors
1276
  - Try reducing epochs or batch size if timeout occurs
1277
  - Ensure your JSON file is properly formatted
1278
+ - Logs update in real-time - you can monitor training progress
1279
  """
1280
  )
1281
 
1282
+ # Define the training action - now with real-time log streaming via yield
1283
  train_button.click(
1284
  fn=train_model,
1285
  inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
 
1305
  - Passages
1306
 
1307
  **Model Details:** LayoutLMv3-base + CRF layer for sequence labeling
1308
+
1309
+ **Features:**
1310
+ - βœ… Real-time log streaming during training
1311
+ - βœ… Progress monitoring with epoch/batch updates
1312
+ - βœ… Immediate model download after completion
1313
+ - βœ… Automatic file preparation for download
1314
  """
1315
  )
1316