msmaje commited on
Commit
3c14fdc
ยท
verified ยท
1 Parent(s): 8466760

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -27
app.py CHANGED
@@ -231,7 +231,7 @@ def compute_metrics(eval_pred):
231
  }
232
 
233
  def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
234
- learning_rate, hf_token, push_to_hub, username, model_name):
235
  """Train the model using inline training (no subprocess)"""
236
  global TRAINING_LOGS, MODEL_PATH, CURRENT_MODEL, CURRENT_TOKENIZER
237
 
@@ -642,22 +642,123 @@ def get_available_datasets():
642
  available_files = ["No CSV files found in current directory"]
643
 
644
  return available_files
 
645
  def display_available_datasets():
646
- datasets = get_available_datasets()
647
- if datasets:
648
- return "**Available CSV files:**\n\n" + "\n".join([f"- {file}" for file in datasets])
649
- else:
650
- return "No CSV files found in the current directory."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651
 
652
- # Initialize the display
653
- refresh_datasets_btn.click(
654
- display_available_datasets,
655
- outputs=available_datasets
656
- )
 
 
 
 
 
657
 
658
- # Show datasets on load
659
- app.load(display_available_datasets, outputs=available_datasets)
660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
661
  gr.Markdown("### Dataset Format Requirements")
662
  gr.Markdown("""
663
  **For training, your CSV file should have:**
@@ -676,7 +777,6 @@ def display_available_datasets():
676
  "Poor TV signal",TV-Radio
677
  ```
678
  """)
679
-
680
  gr.Markdown("### Model Categories")
681
  categories_info = f"""
682
  **The model classifies complaints into these categories:**
@@ -689,23 +789,60 @@ def display_available_datasets():
689
  """
690
  gr.Markdown(categories_info)
691
 
692
- # Launch the app
693
- if __name__ == "__main__":
694
- # Initialize tokenizer on startup
695
- if CURRENT_TOKENIZER is None:
696
- try:
697
- CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
698
- print("โœ… Tokenizer initialized successfully")
699
- except Exception as e:
700
- print(f"โš ๏ธ Warning: Could not initialize tokenizer: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
 
702
- print("๐Ÿš€ Launching BERT Complaint Classifier...")
703
- print("๐Ÿ“ Available at: http://localhost:7860")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
704
 
 
 
 
 
 
705
  app.launch(
706
  server_name="0.0.0.0",
707
  server_port=7860,
708
  share=False,
709
  show_error=True
710
- )
711
-
 
231
  }
232
 
233
  def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
234
+ learning_rate, hf_token, push_to_hub, username, model_name):
235
  """Train the model using inline training (no subprocess)"""
236
  global TRAINING_LOGS, MODEL_PATH, CURRENT_MODEL, CURRENT_TOKENIZER
237
 
 
642
  available_files = ["No CSV files found in current directory"]
643
 
644
  return available_files
645
+
646
  def display_available_datasets():
647
+ datasets = get_available_datasets()
648
+ if datasets:
649
+ return "**Available CSV files:**\n\n" + "\n".join([f"- {file}" for file in datasets])
650
+ else:
651
+ return "No CSV files found in the current directory."
652
+
653
+ # Initialize tokenizer on startup
654
+ if CURRENT_TOKENIZER is None:
655
+ try:
656
+ CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
657
+ print("โœ… Tokenizer initialized successfully")
658
+ except Exception as e:
659
+ print(f"โš ๏ธ Warning: Could not initialize tokenizer: {e}")
660
+
661
+ print("๐Ÿš€ Launching BERT Complaint Classifier...")
662
+ print("๐Ÿ“ Available at: http://localhost:7860")
663
+
664
+ # The entire Gradio UI definition must be within a single block
665
+ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
666
+ gr.Markdown("# BERT Complaint Classifier ๐Ÿ—ฃ๏ธ๐Ÿค–")
667
+ gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
668
+
669
+ with gr.Tab("Fine-tune Model"):
670
+ gr.Markdown("## ๐Ÿ‹๏ธ Fine-tune a New Model")
671
+
672
+ with gr.Column(variant="panel"):
673
+ gr.Markdown("### ๐Ÿ› ๏ธ Training Configuration")
674
+ with gr.Row():
675
+ uploaded_file = gr.File(label="Upload Training CSV File", type="filepath", file_types=["csv"])
676
+ preview_btn = gr.Button("Preview Dataset")
677
+ preview_output = gr.Markdown("Dataset info will appear here")
678
 
679
+ with gr.Row():
680
+ text_column_input = gr.Textbox(label="Text Column Name", value="complaint")
681
+ label_column_input = gr.Textbox(label="Label Column Name", value="category")
682
+
683
+ gr.Markdown("---")
684
+
685
+ with gr.Row():
686
+ num_epochs_slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Epochs")
687
+ batch_size_slider = gr.Slider(minimum=4, maximum=32, step=4, value=8, label="Batch Size")
688
+ learning_rate_slider = gr.Slider(minimum=1e-6, maximum=1e-4, step=1e-6, value=2e-5, label="Learning Rate")
689
 
690
+ gr.Markdown("---")
 
691
 
692
+ gr.Markdown("### โ˜๏ธ Hugging Face Hub (Optional)")
693
+ with gr.Row():
694
+ push_to_hub_checkbox = gr.Checkbox(label="Push to Hugging Face Hub")
695
+ hf_token_input = gr.Textbox(label="Hugging Face Token", type="password")
696
+ with gr.Row():
697
+ hf_username_input = gr.Textbox(label="Hugging Face Username")
698
+ hf_model_name_input = gr.Textbox(label="Model Name (for Hub)", value="bert-complaint-classifier")
699
+
700
+ train_btn = gr.Button("๐Ÿš€ Start Training", variant="primary")
701
+
702
+ gr.Markdown("---")
703
+
704
+ training_log_output = gr.Textbox(label="Training Logs", lines=20, max_lines=20, interactive=False)
705
+
706
+ with gr.Tab("Predict"):
707
+ gr.Markdown("## ๐Ÿ”ฎ Make Predictions")
708
+ gr.Markdown("Choose a method to classify complaints.")
709
+
710
+ with gr.Tab("Predict Single Text"):
711
+ with gr.Column(variant="panel"):
712
+ gr.Markdown("### Classify a Single Complaint")
713
+
714
+ model_path_input = gr.Textbox(
715
+ label="Model Path or Hub ID",
716
+ value="bert-base-uncased",
717
+ placeholder="e.g., local-model or your_username/your_model"
718
+ )
719
+
720
+ with gr.Row():
721
+ text_input = gr.Textbox(label="Complaint Text", lines=3)
722
+ token_count_output = gr.Markdown("Token count: 0/512")
723
+
724
+ predict_btn = gr.Button("Classify Complaint", variant="primary")
725
+ single_prediction_output = gr.Markdown("Prediction will appear here...")
726
+
727
+ # Link token count to text input
728
+ text_input.change(count_tokens, inputs=text_input, outputs=token_count_output)
729
+
730
+ with gr.Tab("Predict from CSV"):
731
+ with gr.Column(variant="panel"):
732
+ gr.Markdown("### Classify Complaints from a CSV File")
733
+ csv_file_input = gr.File(label="Upload CSV File (with 'complaint' column)")
734
+ csv_model_path = gr.Textbox(
735
+ label="Model Path or Hub ID",
736
+ value="local-model",
737
+ placeholder="e.g., local-model or your_username/your_model"
738
+ )
739
+ csv_predict_btn = gr.Button("Run Predictions on CSV", variant="primary")
740
+ csv_prediction_output = gr.Markdown("Predictions will appear here...")
741
+ download_link = gr.File(label="Download Full Predictions", interactive=False)
742
+
743
+ # Link prediction buttons to functions
744
+ predict_btn.click(
745
+ predict_text,
746
+ inputs=[text_input, model_path_input],
747
+ outputs=single_prediction_output
748
+ )
749
+ csv_predict_btn.click(
750
+ predict_csv,
751
+ inputs=[csv_file_input, csv_model_path],
752
+ outputs=[csv_prediction_output, download_link]
753
+ )
754
+
755
+ with gr.Tab("Tools"):
756
+ gr.Markdown("## ๐Ÿ”ง Tools")
757
+ gr.Markdown("Utilities for managing datasets and models.")
758
+
759
+ with gr.Accordion("Dataset Information"):
760
+ available_datasets = gr.Markdown("No CSV files found in the current directory.")
761
+ refresh_datasets_btn = gr.Button("๐Ÿ”„ Refresh Available Datasets")
762
  gr.Markdown("### Dataset Format Requirements")
763
  gr.Markdown("""
764
  **For training, your CSV file should have:**
 
777
  "Poor TV signal",TV-Radio
778
  ```
779
  """)
 
780
  gr.Markdown("### Model Categories")
781
  categories_info = f"""
782
  **The model classifies complaints into these categories:**
 
789
  """
790
  gr.Markdown(categories_info)
791
 
792
+ with gr.Accordion("Push Local Model to Hub"):
793
+ gr.Markdown("Use this to manually push a locally trained model (`./local-model`) to the Hub.")
794
+ with gr.Row():
795
+ hub_username_input_push = gr.Textbox(label="Hugging Face Username")
796
+ hub_model_name_input_push = gr.Textbox(label="Model Name")
797
+ hub_token_input_push = gr.Textbox(label="Hugging Face Token", type="password")
798
+
799
+ push_btn = gr.Button("๐Ÿš€ Push Model to Hub", variant="primary")
800
+ push_output = gr.verse("Results will appear here...")
801
+
802
+ # Link the push button
803
+ push_btn.click(
804
+ push_to_hub_after_training,
805
+ inputs=[gr.Textbox(value=MODEL_PATH, visible=False), hub_username_input_push, hub_model_name_input_push, hub_token_input_push],
806
+ outputs=push_output
807
+ )
808
+
809
+ # All button clicks and UI logic now correctly indented within the app block
810
+ preview_btn.click(
811
+ preview_dataset,
812
+ inputs=[uploaded_file, text_column_input, label_column_input],
813
+ outputs=preview_output
814
+ )
815
 
816
+ train_btn.click(
817
+ train_model_inline,
818
+ inputs=[
819
+ uploaded_file,
820
+ text_column_input,
821
+ label_column_input,
822
+ num_epochs_slider,
823
+ batch_size_slider,
824
+ learning_rate_slider,
825
+ hf_token_input,
826
+ push_to_hub_checkbox,
827
+ hf_username_input,
828
+ hf_model_name_input,
829
+ ],
830
+ outputs=training_log_output,
831
+ )
832
+
833
+ refresh_datasets_btn.click(
834
+ display_available_datasets,
835
+ outputs=available_datasets
836
+ )
837
 
838
+ # Show datasets on load
839
+ app.load(display_available_datasets, outputs=available_datasets)
840
+
841
+ # Launch the app
842
+ if __name__ == "__main__":
843
  app.launch(
844
  server_name="0.0.0.0",
845
  server_port=7860,
846
  share=False,
847
  show_error=True
848
+ )