# LSTM-CRF_Train / app.py
# Last commit by aagamjtdev: "add download button" (f17ce9b)
# import os
# import shutil
# import tempfile
# import gradio as gr
# from huggingface_hub import hf_hub_download, upload_file, HfApi
# import sys
#
# # Add current directory to path to import train_model
# sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
#
# # Configuration
# OUTPUT_DIR = "output_data"
# MODEL_FILE = "model_enhanced.pt"
# VOCAB_FILE = "vocabs_enhanced.pkl"
# CHECKPOINT_FILE = "checkpoint_enhanced.pt"
#
# # IMPORTANT: Update this with your actual Hugging Face repository ID
# REPO_ID = "heerjtdev/LSTM_CRF" # Replace with your repo ID
# # HF_TOKEN = os.environ.get("HF_TOKEN") # Set this as a secret in your Space settings
#
#
# def download_existing_models():
# """Download existing model files from the Hugging Face Hub if available."""
# try:
# api = HfApi()
# #files = api.list_repo_files(REPO_ID, token=HF_TOKEN)
# files = api.list_repo_files(REPO_ID)
#
# os.makedirs(OUTPUT_DIR, exist_ok=True)
#
# downloaded_files = []
#
# # Download model file
# if MODEL_FILE in files:
# print(f"πŸ“₯ Downloading {MODEL_FILE} from Hub...")
# model_path = hf_hub_download(
# repo_id=REPO_ID,
# filename=MODEL_FILE,
# # token=HF_TOKEN,
# local_dir=OUTPUT_DIR,
# force_download=True # Always get latest version
# )
# downloaded_files.append(MODEL_FILE)
# print(f"βœ… Downloaded {MODEL_FILE}")
#
# # Download vocab file
# if VOCAB_FILE in files:
# print(f"πŸ“₯ Downloading {VOCAB_FILE} from Hub...")
# vocab_path = hf_hub_download(
# repo_id=REPO_ID,
# filename=VOCAB_FILE,
# # token=HF_TOKEN,
# local_dir=OUTPUT_DIR,
# force_download=True # Always get latest version
# )
# downloaded_files.append(VOCAB_FILE)
# print(f"βœ… Downloaded {VOCAB_FILE}")
#
# # Download checkpoint file (optional, for resuming training)
# if CHECKPOINT_FILE in files:
# print(f"πŸ“₯ Downloading {CHECKPOINT_FILE} from Hub...")
# checkpoint_path = hf_hub_download(
# repo_id=REPO_ID,
# filename=CHECKPOINT_FILE,
# # token=HF_TOKEN,
# local_dir=OUTPUT_DIR,
# force_download=True
# )
# downloaded_files.append(CHECKPOINT_FILE)
# print(f"βœ… Downloaded {CHECKPOINT_FILE}")
#
# if downloaded_files:
# return f"βœ… Downloaded from Hub: {', '.join(downloaded_files)}"
# else:
# return "ℹ️ No existing model files found in repository. Starting fresh."
# except Exception as e:
# error_msg = f"⚠️ Could not download existing models: {str(e)}"
# print(error_msg)
# return error_msg
#
#
# def train_model(dataset_file, progress=gr.Progress()):
# """Train the model with the uploaded dataset."""
# if dataset_file is None:
# return "❌ Please upload a dataset file!", None, None
#
# try:
# # Step 1: Download existing models from Hub (if any) BEFORE training starts
# progress(0.05, desc="Checking Hugging Face Hub for existing models...")
# download_status = download_existing_models()
# status_log = f"{download_status}\n\n"
# yield status_log, None, None
#
# # Step 2: Save uploaded file
# progress(0.1, desc="Processing uploaded dataset...")
# dataset_path = dataset_file.name
# status_log += f"πŸ“‚ Dataset uploaded: {os.path.basename(dataset_path)}\n\n"
# yield status_log, None, None
#
# # Step 3: Import and run training
# progress(0.15, desc="Initializing training...")
# status_log += "πŸš€ Starting training...\n"
# status_log += "πŸ“Š This may take a while. Training progress will appear in the terminal.\n\n"
# yield status_log, None, None
#
# # Import the training module
# try:
# import train_model as tm
# print("=" * 80)
# print("TRAINING STARTED")
# print("=" * 80)
#
# # Run training - this will handle model loading internally
# progress(0.2, desc="Training in progress... (check terminal for details)")
# tm.train_from_json(dataset_path)
#
# print("=" * 80)
# print("TRAINING COMPLETED")
# print("=" * 80)
#
# status_log += "βœ… Training completed successfully!\n\n"
# yield status_log, None, None
#
# except ImportError as ie:
# error_msg = f"❌ Failed to import training module: {str(ie)}\n"
# error_msg += "Make sure train_model.py is in the same directory as app.py"
# yield status_log + error_msg, None, None
# return
# except Exception as train_error:
# error_msg = f"❌ Training failed with error:\n{str(train_error)}\n"
# yield status_log + error_msg, None, None
# return
#
# # Step 4: Verify files exist
# progress(0.85, desc="Verifying trained model files...")
# model_path = os.path.join(OUTPUT_DIR, MODEL_FILE)
# vocab_path = os.path.join(OUTPUT_DIR, VOCAB_FILE)
# checkpoint_path = os.path.join(OUTPUT_DIR, CHECKPOINT_FILE)
#
# files_exist = []
# if os.path.exists(model_path):
# files_exist.append(MODEL_FILE)
# if os.path.exists(vocab_path):
# files_exist.append(VOCAB_FILE)
#
# if not files_exist:
# error_msg = "❌ Error: Model files were not created. Check training logs."
# yield status_log + error_msg, None, None
# return
#
# status_log += f"βœ… Found trained files: {', '.join(files_exist)}\n\n"
# yield status_log, None, None
#
# # Step 5: Upload to Hub
# progress(0.9, desc="Uploading models to Hugging Face Hub...")
# status_log += "☁️ Uploading to Hugging Face Hub...\n"
# yield status_log, None, None
#
# upload_status = []
#
# if os.path.exists(model_path):
# try:
# upload_file(
# path_or_fileobj=model_path,
# path_in_repo=MODEL_FILE,
# repo_id=REPO_ID,
# # token=HF_TOKEN,
# commit_message="Update trained model"
# )
# upload_status.append(MODEL_FILE)
# print(f"βœ… Uploaded {MODEL_FILE} to Hub")
# except Exception as e:
# print(f"⚠️ Failed to upload {MODEL_FILE}: {e}")
#
# if os.path.exists(vocab_path):
# try:
# upload_file(
# path_or_fileobj=vocab_path,
# path_in_repo=VOCAB_FILE,
# repo_id=REPO_ID,
# # token=HF_TOKEN,
# commit_message="Update vocabulary"
# )
# upload_status.append(VOCAB_FILE)
# print(f"βœ… Uploaded {VOCAB_FILE} to Hub")
# except Exception as e:
# print(f"⚠️ Failed to upload {VOCAB_FILE}: {e}")
#
# # Also upload checkpoint for future resume capability
# if os.path.exists(checkpoint_path):
# try:
# upload_file(
# path_or_fileobj=checkpoint_path,
# path_in_repo=CHECKPOINT_FILE,
# repo_id=REPO_ID,
# # token=HF_TOKEN,
# commit_message="Update checkpoint"
# )
# upload_status.append(CHECKPOINT_FILE)
# print(f"βœ… Uploaded {CHECKPOINT_FILE} to Hub")
# except Exception as e:
# print(f"⚠️ Failed to upload {CHECKPOINT_FILE}: {e}")
#
# if upload_status:
# status_log += f"βœ… Uploaded to Hub: {', '.join(upload_status)}\n\n"
# else:
# status_log += "⚠️ Warning: No files were uploaded to Hub\n\n"
#
# yield status_log, None, None
#
# # Step 6: Copy to temp directory for download
# progress(0.95, desc="Preparing download files...")
# temp_dir = tempfile.mkdtemp()
#
# model_download = None
# vocab_download = None
#
# if os.path.exists(model_path):
# temp_model = os.path.join(temp_dir, MODEL_FILE)
# shutil.copy2(model_path, temp_model)
# model_download = temp_model
# print(f"πŸ“¦ Prepared {MODEL_FILE} for download")
#
# if os.path.exists(vocab_path):
# temp_vocab = os.path.join(temp_dir, VOCAB_FILE)
# shutil.copy2(vocab_path, temp_vocab)
# vocab_download = temp_vocab
# print(f"πŸ“¦ Prepared {VOCAB_FILE} for download")
#
# progress(1.0, desc="Complete!")
#
# status_log += "πŸ“¦ Files ready for download below!\n"
# status_log += "\n" + "=" * 50 + "\n"
# status_log += "TRAINING COMPLETE - You can now download the model files\n"
# status_log += "=" * 50
#
# yield status_log, model_download, vocab_download
#
# except Exception as e:
# error_msg = f"❌ Unexpected error: {str(e)}\n"
# import traceback
# error_msg += f"\nTraceback:\n{traceback.format_exc()}"
# yield error_msg, None, None
#
#
# def download_models_from_hub():
# """Download the latest models from the Hugging Face Hub."""
# try:
# os.makedirs(OUTPUT_DIR, exist_ok=True)
#
# api = HfApi()
# #files = api.list_repo_files(REPO_ID, token=HF_TOKEN)
# files = api.list_repo_files(REPO_ID)
#
# downloaded_files = []
#
# # Download model
# if MODEL_FILE in files:
# print(f"πŸ“₯ Downloading {MODEL_FILE} from Hub...")
# model_path = hf_hub_download(
# repo_id=REPO_ID,
# filename=MODEL_FILE,
# # token=HF_TOKEN,
# local_dir=OUTPUT_DIR,
# force_download=True
# )
# downloaded_files.append(MODEL_FILE)
# else:
# return f"❌ {MODEL_FILE} not found in repository", None, None
#
# # Download vocab
# if VOCAB_FILE in files:
# print(f"πŸ“₯ Downloading {VOCAB_FILE} from Hub...")
# vocab_path = hf_hub_download(
# repo_id=REPO_ID,
# filename=VOCAB_FILE,
# # token=HF_TOKEN,
# local_dir=OUTPUT_DIR,
# force_download=True
# )
# downloaded_files.append(VOCAB_FILE)
# else:
# return f"❌ {VOCAB_FILE} not found in repository", None, None
#
# # Copy to temp for download
# temp_dir = tempfile.mkdtemp()
# temp_model = os.path.join(temp_dir, MODEL_FILE)
# temp_vocab = os.path.join(temp_dir, VOCAB_FILE)
#
# shutil.copy2(os.path.join(OUTPUT_DIR, MODEL_FILE), temp_model)
# shutil.copy2(os.path.join(OUTPUT_DIR, VOCAB_FILE), temp_vocab)
#
# success_msg = f"βœ… Successfully downloaded from Hub:\n"
# success_msg += f" β€’ {MODEL_FILE}\n"
# success_msg += f" β€’ {VOCAB_FILE}\n\n"
# success_msg += "πŸ“¦ Files are ready to download below!"
#
# return success_msg, temp_model, temp_vocab
#
# except Exception as e:
# error_msg = f"❌ Error downloading models: {str(e)}\n\n"
# error_msg += f"Make sure:\n"
# error_msg += f"1. REPO_ID is set correctly: {REPO_ID}\n"
# error_msg += f"2. HF_TOKEN is set in Space secrets\n"
# error_msg += f"3. Model files exist in the repository"
# return error_msg, None, None
#
#
# # Create Gradio interface
# with gr.Blocks(title="MCQ Structure Extraction - Model Training", theme=gr.themes.Soft()) as demo:
# gr.Markdown(
# """
# # πŸŽ“ MCQ Structure Extraction - Model Training
#
# Train a BiLSTM-CRF model with deep layout understanding for extracting structured information from MCQ documents.
#
# ## πŸ“‹ Instructions:
# 1. **Upload Dataset**: Provide your unified JSON file containing tokens, bounding boxes, and labels
# 2. **Train Model**: Click "Start Training" and wait for completion (this may take a while)
# 3. **Download Models**: Once training is complete, download the trained model and vocabulary files
#
# ## πŸ“₯ Or Download Existing Models:
# If you just want to download the latest trained models from the repository, use the "Download from Hub" tab.
#
# ---
# """
# )
#
# with gr.Tab("πŸš€ Train New Model"):
# gr.Markdown(
# """
# ### Training Process:
# The app will automatically:
# 1. βœ… Download any existing models from Hugging Face Hub (for resuming training)
# 2. 🎯 Train the model on your uploaded dataset
# 3. ☁️ Upload the trained models back to the Hub
# 4. πŸ“₯ Provide download links for the trained files
#
# **Note**: Training progress details appear in the terminal/logs. The status box shows major milestones.
# """
# )
#
# with gr.Row():
# with gr.Column():
# dataset_input = gr.File(
# label="πŸ“‚ Upload Training Dataset (JSON)",
# file_types=[".json"],
# type="filepath"
# )
# train_button = gr.Button("πŸš€ Start Training", variant="primary", size="lg")
#
# with gr.Column():
# status_output = gr.Textbox(
# label="πŸ“Š Training Status",
# lines=12,
# interactive=False,
# show_copy_button=True
# )
#
# gr.Markdown("### πŸ“¦ Download Trained Models")
# with gr.Row():
# model_output = gr.File(label="πŸ’Ύ Model File (.pt)")
# vocab_output = gr.File(label="πŸ“š Vocabulary File (.pkl)")
#
# train_button.click(
# fn=train_model,
# inputs=[dataset_input],
# outputs=[status_output, model_output, vocab_output]
# )
#
# with gr.Tab("☁️ Download from Hub"):
# gr.Markdown(
# """
# ### Download Pre-trained Models
#
# Download the latest trained models directly from your Hugging Face repository.
# This is useful if:
# - You want to use pre-trained models without training
# - You need to download models trained in a previous session
# - You want to get the latest version from the Hub
#
# The downloaded files can be used for inference with your MCQ extraction pipeline.
# """
# )
#
# download_button = gr.Button("☁️ Download Latest Models from Hub", variant="primary", size="lg")
#
# download_status = gr.Textbox(
# label="Download Status",
# lines=6,
# interactive=False,
# show_copy_button=True
# )
#
# gr.Markdown("### πŸ“¦ Downloaded Files")
# with gr.Row():
# hub_model_output = gr.File(label="πŸ’Ύ Model File (.pt)")
# hub_vocab_output = gr.File(label="πŸ“š Vocabulary File (.pkl)")
#
# download_button.click(
# fn=download_models_from_hub,
# outputs=[download_status, hub_model_output, hub_vocab_output]
# )
#
# gr.Markdown(
# """
# ---
# ### βš™οΈ Model Configuration:
#
# **Architecture:**
# - BiLSTM-CRF with spatial attention mechanism
# - Word embeddings + Character-level CNN
# - Bounding box encoding with MLP
# - Spatial & context feature extraction
# - Learnable positional embeddings
#
# **Features Used:**
# - Token text (word-level and character-level)
# - Bounding box coordinates (normalized)
# - Spatial features: vertical spacing, alignment, dimensions (11 features)
# - Context features: surrounding question/option markers (8 features)
#
# **Output Labels (13 total):**
# - Questions, Options, Answers, Images, Section Headings, Passages (BIO tagging)
#
# **Training Parameters:**
# - Batch Size: 8
# - Epochs: 10 (with early stopping after 10 epochs without improvement)
# - Learning Rate: 5e-4 (AdamW optimizer with OneCycleLR scheduler)
# - Hidden Size: 768
# - Total Parameters: ~15.6M
#
# **Hardware Requirements:**
# - GPU recommended for reasonable training speed
# - CPU training supported but significantly slower
#
# ---
#
#
#
# **Environment Variables Required:**
# - `SPACE_ID`: Your Hugging Face Space/Repo ID (auto-set in Spaces)
# - `HF_TOKEN`: Your Hugging Face write token (set as a secret)
#
# **Model Persistence:**
# - Models are automatically saved to `output_data/` directory
# - Best model is uploaded to Hugging Face Hub after each improvement
# - Training can be resumed from checkpoints
# """
# )
#
# # Launch the app
# if __name__ == "__main__":
# demo.launch()
import os
import shutil
import tempfile
import gradio as gr
from huggingface_hub import hf_hub_download, upload_file, HfApi
import sys
import glob
# Make sibling modules (notably train_model.py) importable no matter which
# working directory the app is launched from.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Configuration
# Local directory holding model artifacts: Hub downloads land here and the app
# looks here for training outputs. NOTE: relative to the current working
# directory, not to this file.
OUTPUT_DIR = "output_data"
MODEL_FILE = "model_enhanced.pt"
VOCAB_FILE = "vocabs_enhanced.pkl"
CHECKPOINT_FILE = "checkpoint_enhanced.pt"

# Hugging Face repository used to pull existing artifacts and push new ones.
# Set the HF_REPO_ID environment variable (e.g. a Space variable) to override
# the default instead of editing this file; an empty value falls back too.
REPO_ID = os.environ.get("HF_REPO_ID") or "heerjtdev/LSTM_CRF"
# Authentication: huggingface_hub reads the HF_TOKEN environment variable on its
# own, so set HF_TOKEN (with write access) as a Space secret to enable uploads.
def download_existing_models():
    """Fetch previously trained artifacts from the Hub into OUTPUT_DIR.

    MODEL_FILE, VOCAB_FILE and CHECKPOINT_FILE are each downloaded only if
    they are listed in REPO_ID; absent files are skipped. Downloads use
    ``force_download=True`` so the local copy always matches the Hub.

    Returns:
        A human-readable status string. Any failure (network, auth, missing
        repo) is reported in the returned string instead of being raised, so
        training can still start from scratch.
    """
    try:
        # files = HfApi().list_repo_files(REPO_ID, token=HF_TOKEN)
        repo_files = set(HfApi().list_repo_files(REPO_ID))
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        downloaded_files = []
        # The checkpoint is optional; it is fetched so training can resume.
        for filename in (MODEL_FILE, VOCAB_FILE, CHECKPOINT_FILE):
            if filename not in repo_files:
                continue
            print(f"πŸ“₯ Downloading {filename} from Hub...")
            hf_hub_download(
                repo_id=REPO_ID,
                filename=filename,
                # token=HF_TOKEN,
                local_dir=OUTPUT_DIR,
                force_download=True,  # Always get latest version
            )
            downloaded_files.append(filename)
            print(f"βœ… Downloaded {filename}")

        if downloaded_files:
            return f"βœ… Downloaded from Hub: {', '.join(downloaded_files)}"
        return "ℹ️ No existing model files found in repository. Starting fresh."
    except Exception as e:
        # Best-effort by design: report and let the caller continue.
        error_msg = f"⚠️ Could not download existing models: {str(e)}"
        print(error_msg)
        return error_msg
def _upload_artifacts(artifacts):
    """Upload each existing local file to REPO_ID; return the names that succeeded.

    Args:
        artifacts: Iterable of ``(local_path, path_in_repo, commit_message)``.
            Entries whose local file is missing are skipped; an upload failure
            is logged and skipped so one bad file does not block the others.
    """
    uploaded = []
    for local_path, repo_name, commit_message in artifacts:
        if not os.path.exists(local_path):
            continue
        try:
            upload_file(
                path_or_fileobj=local_path,
                path_in_repo=repo_name,
                repo_id=REPO_ID,
                # token=HF_TOKEN,
                commit_message=commit_message,
            )
        except Exception as e:
            print(f"⚠️ Failed to upload {repo_name}: {e}")
            continue
        uploaded.append(repo_name)
        print(f"βœ… Uploaded {repo_name} to Hub")
    return uploaded


def _copy_for_download(paths):
    """Copy each existing file into one fresh temp dir for serving to the user.

    Returns a list parallel to ``paths`` holding the copied path, or ``None``
    where the source file does not exist.
    """
    temp_dir = tempfile.mkdtemp()
    copies = []
    for path in paths:
        if not os.path.exists(path):
            copies.append(None)
            continue
        filename = os.path.basename(path)
        target = os.path.join(temp_dir, filename)
        shutil.copy2(path, target)
        copies.append(target)
        print(f"πŸ“¦ Prepared {filename} for download")
    return copies


def train_model(dataset_file, progress=gr.Progress()):
    """Train on an uploaded dataset, push the results to the Hub, offer downloads.

    This is a Gradio generator handler. Every yield is a 5-tuple
    ``(status_text, model_file, vocab_file, model_file, vocab_file)``; the file
    pair is repeated so two sets of File outputs can be filled. File slots are
    ``None`` until the final step succeeds.

    Args:
        dataset_file: The uploaded JSON dataset. Either a file wrapper exposing
            ``.name`` or a plain path string is accepted.
        progress: Gradio progress tracker, injected by the framework.
    """
    if dataset_file is None:
        # This function is a generator: a ``return value`` here would never be
        # shown in the UI, so the message must be yielded.
        yield "❌ Please upload a dataset file!", None, None, None, None
        return
    try:
        # Step 1: Download existing models from Hub (if any) BEFORE training starts
        progress(0.05, desc="Checking Hugging Face Hub for existing models...")
        download_status = download_existing_models()
        status_log = f"{download_status}\n\n"
        # Reset download outputs before training starts
        yield status_log, None, None, None, None

        # Step 2: Resolve the uploaded file's path
        progress(0.1, desc="Processing uploaded dataset...")
        dataset_path = getattr(dataset_file, "name", dataset_file)
        status_log += f"πŸ“‚ Dataset uploaded: {os.path.basename(dataset_path)}\n\n"
        yield status_log, None, None, None, None

        # Step 3: Import and run training
        progress(0.15, desc="Initializing training...")
        status_log += "πŸš€ Starting training...\n"
        status_log += "πŸ“Š This may take a while. Training progress will appear in the terminal.\n\n"
        yield status_log, None, None, None, None

        # Only the import is guarded here, so an ImportError raised *during*
        # training is reported as a training failure, not a missing module.
        try:
            import train_model as tm
        except ImportError as ie:
            error_msg = f"❌ Failed to import training module: {str(ie)}\n"
            error_msg += "Make sure train_model.py is in the same directory as app.py"
            yield status_log + error_msg, None, None, None, None
            return

        try:
            print("=" * 80)
            print("TRAINING STARTED")
            print("=" * 80)
            # Run training - this will handle model loading internally
            progress(0.2, desc="Training in progress... (check terminal for details)")
            tm.train_from_json(dataset_path)
            print("=" * 80)
            print("TRAINING COMPLETED")
            print("=" * 80)
        except Exception as train_error:
            import traceback
            traceback.print_exc()  # full stack goes to the terminal/logs
            error_msg = f"❌ Training failed with error:\n{str(train_error)}\n"
            yield status_log + error_msg, None, None, None, None
            return
        status_log += "βœ… Training completed successfully!\n\n"
        yield status_log, None, None, None, None

        # Step 4: Verify files exist
        progress(0.85, desc="Verifying trained model files...")
        model_path = os.path.join(OUTPUT_DIR, MODEL_FILE)
        vocab_path = os.path.join(OUTPUT_DIR, VOCAB_FILE)
        checkpoint_path = os.path.join(OUTPUT_DIR, CHECKPOINT_FILE)
        files_exist = [
            name
            for name, path in ((MODEL_FILE, model_path), (VOCAB_FILE, vocab_path))
            if os.path.exists(path)
        ]
        if not files_exist:
            error_msg = "❌ Error: Model files were not created. Check training logs."
            yield status_log + error_msg, None, None, None, None
            return
        status_log += f"βœ… Found trained files: {', '.join(files_exist)}\n\n"
        yield status_log, None, None, None, None

        # Step 5: Upload to Hub (the checkpoint too, for future resume capability)
        progress(0.9, desc="Uploading models to Hugging Face Hub...")
        status_log += "☁️ Uploading to Hugging Face Hub...\n"
        yield status_log, None, None, None, None
        upload_status = _upload_artifacts([
            (model_path, MODEL_FILE, "Update trained model"),
            (vocab_path, VOCAB_FILE, "Update vocabulary"),
            (checkpoint_path, CHECKPOINT_FILE, "Update checkpoint"),
        ])
        if upload_status:
            status_log += f"βœ… Uploaded to Hub: {', '.join(upload_status)}\n\n"
        else:
            status_log += "⚠️ Warning: No files were uploaded to Hub\n\n"
        yield status_log, None, None, None, None

        # Step 6: Copy to temp directory for download
        progress(0.95, desc="Preparing download files...")
        model_download, vocab_download = _copy_for_download([model_path, vocab_path])
        progress(1.0, desc="Complete!")
        status_log += "πŸ“¦ Files ready for download below!\n"
        status_log += "\n" + "=" * 50 + "\n"
        status_log += "TRAINING COMPLETE - You can now download the model files\n"
        status_log += "=" * 50
        # The file pair is repeated to fill both sets of File outputs.
        yield status_log, model_download, vocab_download, model_download, vocab_download
    except Exception as e:
        import traceback
        error_msg = f"❌ Unexpected error: {str(e)}\n"
        error_msg += f"\nTraceback:\n{traceback.format_exc()}"
        # Return Nones for all file outputs
        yield error_msg, None, None, None, None
def download_models_from_hub():
    """Download the latest model and vocabulary from REPO_ID for the user.

    Both files must be present in the repository. Their presence is checked
    before anything is downloaded. This avoids a partial update, where the Hub
    model overwrites the local one in OUTPUT_DIR while a stale local vocabulary
    stays behind. Downloads overwrite the copies in OUTPUT_DIR and are then
    copied into a fresh temp directory for serving.

    Returns:
        5-tuple ``(status_text, model_path, vocab_path, model_path, vocab_path)``.
        The file pair is repeated to fill two sets of File outputs; all paths
        are ``None`` on failure.
    """
    try:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        # files = HfApi().list_repo_files(REPO_ID, token=HF_TOKEN)
        repo_files = set(HfApi().list_repo_files(REPO_ID))

        required = (MODEL_FILE, VOCAB_FILE)
        for filename in required:
            if filename not in repo_files:
                return f"❌ {filename} not found in repository", None, None, None, None

        for filename in required:
            print(f"πŸ“₯ Downloading {filename} from Hub...")
            hf_hub_download(
                repo_id=REPO_ID,
                filename=filename,
                # token=HF_TOKEN,
                local_dir=OUTPUT_DIR,
                force_download=True,
            )

        # Copy to temp for download
        temp_dir = tempfile.mkdtemp()
        temp_model = os.path.join(temp_dir, MODEL_FILE)
        temp_vocab = os.path.join(temp_dir, VOCAB_FILE)
        shutil.copy2(os.path.join(OUTPUT_DIR, MODEL_FILE), temp_model)
        shutil.copy2(os.path.join(OUTPUT_DIR, VOCAB_FILE), temp_vocab)

        success_msg = (
            "βœ… Successfully downloaded from Hub:\n"
            f" β€’ {MODEL_FILE}\n"
            f" β€’ {VOCAB_FILE}\n\n"
            "πŸ“¦ Files are ready to download below!"
        )
        return success_msg, temp_model, temp_vocab, temp_model, temp_vocab
    except Exception as e:
        error_msg = (
            f"❌ Error downloading models: {str(e)}\n\n"
            "Make sure:\n"
            f"1. REPO_ID is set correctly: {REPO_ID}\n"
            "2. HF_TOKEN is set in Space secrets\n"
            "3. Model files exist in the repository"
        )
        return error_msg, None, None, None, None
def _format_size(num_bytes):
    """Render a byte count as a short human-readable string (bytes/KB/MB/GB)."""
    for factor, unit in ((1024 ** 3, "GB"), (1024 ** 2, "MB"), (1024, "KB")):
        if num_bytes >= factor:
            return f"{num_bytes / factor:.2f} {unit}"
    return f"{num_bytes} bytes"


def check_local_files():
    """Report what is in OUTPUT_DIR and offer the core model files for download.

    Returns:
        ``(status_text, model_path, vocab_path)``. Each path points directly at
        the file inside OUTPUT_DIR, or is ``None`` when that file is absent.
    """
    # isdir (not exists): a stray *file* named OUTPUT_DIR would make listdir raise.
    if not os.path.isdir(OUTPUT_DIR):
        return f"ℹ️ Directory **'{OUTPUT_DIR}'** does not exist.", None, None
    all_files = os.listdir(OUTPUT_DIR)
    if not all_files:
        return f"ℹ️ Directory **'{OUTPUT_DIR}'** is empty.", None, None

    model_path = os.path.join(OUTPUT_DIR, MODEL_FILE)
    vocab_path = os.path.join(OUTPUT_DIR, VOCAB_FILE)
    model_download = model_path if os.path.isfile(model_path) else None
    vocab_download = vocab_path if os.path.isfile(vocab_path) else None

    # Core artifacts first (model, vocab, checkpoint), then the rest alphabetically.
    priority = {MODEL_FILE: 0, VOCAB_FILE: 1, CHECKPOINT_FILE: 2}
    sorted_files = sorted(all_files, key=lambda name: (priority.get(name, 3), name))

    file_list = []
    total_size = 0
    for filename in sorted_files:
        filepath = os.path.join(OUTPUT_DIR, filename)
        if not os.path.isfile(filepath):
            continue  # subdirectories are neither listed nor counted
        size_bytes = os.path.getsize(filepath)
        total_size += size_bytes
        file_list.append(f"β€’ **{filename}** (Size: {_format_size(size_bytes)})")

    header = (
        f"βœ… Contents of **'{OUTPUT_DIR}'** ({len(file_list)} files, "
        f"Total Size: {_format_size(total_size)}):\n"
    )
    if model_download and vocab_download:
        header += "\n**πŸ“¦ Core model files found! Ready for download below.**"
    elif model_download or vocab_download:
        header += "\n**⚠️ Found some model files, but not both.**"
    return header + "\n" + "\n".join(file_list), model_download, vocab_download
def clear_local_memory():
    """Recursively delete OUTPUT_DIR so the next run starts from a clean slate.

    Returns:
        ``(status_text, None, None)``; the trailing ``None`` values blank the
        two File outputs this handler is wired to.
    """
    if not os.path.exists(OUTPUT_DIR):
        return f"ℹ️ Local directory **'{OUTPUT_DIR}'** does not exist. No memory to clear.", None, None
    try:
        shutil.rmtree(OUTPUT_DIR)
    except Exception as err:
        return f"❌ Error clearing memory (deleting '{OUTPUT_DIR}'): {str(err)}", None, None
    return f"πŸ—‘οΈ Successfully deleted local directory **'{OUTPUT_DIR}'** and all its contents. Memory cleared.", None, None
# Create Gradio interface.
# Each File component is created inside the Row where it should appear; Gradio
# renders a component where it is created, so creating it elsewhere and merely
# assigning it inside a Row leaves that Row empty.
with gr.Blocks(title="MCQ Structure Extraction - Model Training", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # πŸŽ“ MCQ Structure Extraction - Model Training

        Train a BiLSTM-CRF model with deep layout understanding for extracting structured information from MCQ documents.

        ## πŸ“‹ Instructions:
        1. **Upload Dataset**: Provide your unified JSON file containing tokens, bounding boxes, and labels
        2. **Train Model**: Click "Start Training" and wait for completion (this may take a while)
        3. **Download Models**: Once training is complete, download the trained model and vocabulary files

        ## πŸ“₯ Or Download Existing Models:
        If you just want to download the latest trained models from the repository, use the "Download from Hub" tab.

        ---
        """
    )

    with gr.Tab("πŸš€ Train New Model"):
        gr.Markdown(
            """
            ### Training Process:
            The app will automatically:
            1. βœ… Download any existing models from Hugging Face Hub (for resuming training)
            2. 🎯 Train the model on your uploaded dataset
            3. ☁️ Upload the trained models back to the Hub
            4. πŸ“₯ Provide download links for the trained files

            **Note**: Training progress details appear in the terminal/logs. The status box shows major milestones.
            """
        )
        with gr.Row():
            with gr.Column():
                dataset_input = gr.File(
                    label="πŸ“‚ Upload Training Dataset (JSON)",
                    file_types=[".json"],
                    type="filepath",
                )
                train_button = gr.Button("πŸš€ Start Training", variant="primary", size="lg")
                # Utilities for inspecting / wiping the local OUTPUT_DIR.
                with gr.Row():
                    check_button = gr.Button("πŸ”Ž Check Local Models", variant="secondary")
                    clear_button = gr.Button("🧹 Clear Local Memory", variant="stop")
            with gr.Column():
                status_output = gr.Textbox(
                    label="πŸ“Š Training/Utility Status",
                    lines=12,
                    interactive=False,
                    show_copy_button=True,
                )
        gr.Markdown("### πŸ“¦ Download Trained/Local Models")
        with gr.Row():
            train_model_output = gr.File(label="πŸ’Ύ Model File (.pt)", interactive=False)
            train_vocab_output = gr.File(label="πŸ“š Vocabulary File (.pkl)", interactive=False)

    with gr.Tab("☁️ Download from Hub"):
        gr.Markdown(
            """
            ### Download Pre-trained Models
            Download the latest trained models directly from your Hugging Face repository.
            """
        )
        download_button = gr.Button("☁️ Download Latest Models from Hub", variant="primary", size="lg")
        download_status = gr.Textbox(
            label="Download Status",
            lines=6,
            interactive=False,
            show_copy_button=True,
        )
        gr.Markdown("### πŸ“¦ Downloaded Files")
        with gr.Row():
            hub_model_output = gr.File(label="πŸ’Ύ Model File (.pt)", interactive=False)
            hub_vocab_output = gr.File(label="πŸ“š Vocabulary File (.pkl)", interactive=False)

    # Kept for backward compatibility: the former shared download components.
    download_model_output = train_model_output
    download_vocab_output = train_vocab_output

    # train_model and download_models_from_hub both produce
    # (status, model, vocab, model, vocab): the second pair mirrors the files
    # into the other tab's File components.
    train_button.click(
        fn=train_model,
        inputs=[dataset_input],
        outputs=[status_output, train_model_output, train_vocab_output, hub_model_output, hub_vocab_output],
    )
    download_button.click(
        fn=download_models_from_hub,
        outputs=[download_status, hub_model_output, hub_vocab_output, train_model_output, train_vocab_output],
    )
    # The local-file utilities produce (status, model, vocab).
    check_button.click(
        fn=check_local_files,
        inputs=[],
        outputs=[status_output, train_model_output, train_vocab_output],
    )
    clear_button.click(
        fn=clear_local_memory,
        inputs=[],
        outputs=[status_output, train_model_output, train_vocab_output],
    )

    gr.Markdown(
        f"""
        ---
        ### βš™οΈ Model Configuration:

        **Architecture:**
        - BiLSTM-CRF with spatial attention mechanism
        - Word embeddings + Character-level CNN
        - Bounding box encoding with MLP
        - Spatial & context feature extraction
        - Learnable positional embeddings

        **Features Used:**
        - Token text (word-level and character-level)
        - Bounding box coordinates (normalized)
        - Spatial features: vertical spacing, alignment, dimensions (11 features)
        - Context features: surrounding question/option markers (8 features)

        **Output Labels (13 total):**
        - Questions, Options, Answers, Images, Section Headings, Passages (BIO tagging)

        **Training Parameters:**
        - Batch Size: 8
        - Epochs: 10 (with early stopping after 10 epochs without improvement)
        - Learning Rate: 5e-4 (AdamW optimizer with OneCycleLR scheduler)
        - Hidden Size: 768
        - Total Parameters: ~15.6M

        **Hardware Requirements:**
        - GPU recommended for reasonable training speed
        - CPU training supported but significantly slower

        ---

        **Hub Configuration:**
        - Model repository: `{REPO_ID}`
        - `HF_TOKEN`: a Hugging Face token with write access to that repository (set as a Space secret; required for uploads)

        **Model Persistence:**
        - Models are saved to the `{OUTPUT_DIR}/` directory
        - The trained model, vocabulary and checkpoint are uploaded to the Hub when training finishes
        - Existing models and the latest checkpoint are downloaded from the Hub before each training run
        """
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()