smarthillc committed on
Commit
d96cedb
·
1 Parent(s): 0e386ca

Add training app with Flan-T5 implementation and datasets

Browse files
README.md CHANGED
@@ -1,13 +1,48 @@
1
  ---
2
- title: Win Stack
3
- emoji: 🔥
4
- colorFrom: green
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.37.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: Winstack
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Resume Normalizer Trainer
3
+ emoji: 📝
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
+ hardware: 4xL4
12
  ---
13
 
14
+ # Resume Normalizer Trainer
15
+
16
+ Fine-tune a Flan-T5 model for resume entity normalization and deduplication.
17
+
18
+ ## Features
19
+
20
+ - **Company Name Normalization**: Handle mergers, acquisitions, and rebranding (e.g., "Facebook" → "Meta Platforms Inc.")
21
+ - **Job Title Standardization**: Recognize equivalent roles and seniority levels (e.g., "SWE" → "Software Engineer")
22
+ - **Skills Normalization**: Standardize technology names and abbreviations (e.g., "JS" → "JavaScript")
23
+ - **Binary Equivalency Detection**: Determine if two entities refer to the same thing
24
+
25
+ ## Model Details
26
+
27
+ - **Base Model**: Google Flan-T5 (instruction-tuned for better zero-shot performance)
28
+ - **Fine-tuning Method**: LoRA (Low-Rank Adaptation) for efficient training
29
+ - **Parameters**: 250M (T5-Base) or 770M (T5-Large)
30
+ - **Training Data**: 9,302 high-quality examples (478 manual + 8,824 synthetic)
31
+
32
+ ## Usage
33
+
34
+ 1. Check that training data is available using the "Check Data" tab
35
+ 2. Enter your HuggingFace token and username
36
+ 3. Select model size and training epochs
37
+ 4. Click "Start Training" and monitor progress in the "Training Status" tab
38
+ 5. Once complete, your model will be available on HuggingFace Hub
39
+
40
+ ## Expected Performance
41
+
42
+ - **Inference Speed**: <100ms per query
43
+ - **Accuracy**: >90% on entity normalization tasks
44
+ - **Memory Usage**: ~1GB (T5-Base) or ~3GB (T5-Large)
45
+
46
+ ## Hardware Requirements
47
+
48
+ This Space runs on 4xL4 GPUs (96GB total VRAM) for efficient distributed training.
app.py CHANGED
@@ -1,7 +1,252 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import subprocess
import sys
import threading
import time

import gradio as gr
6
 
7
+ # Global variable to track training status
8
+ training_status = {"status": "idle", "message": "", "progress": 0}
9
 
10
+ def check_data():
11
+ """Check if data is available"""
12
+ files = []
13
+ if os.path.exists("combined_final_training_data.csv"):
14
+ files.append("✅ Combined dataset: 9,302 examples")
15
+ if os.path.exists("combined_balanced_training_data.csv"):
16
+ files.append("✅ Balanced dataset: 8,304 examples")
17
+ if os.path.exists("data/clean_training_data.csv"):
18
+ files.append("✅ Clean manual data: 478 examples")
19
+
20
+ if not files:
21
+ return "❌ No training data found. Please upload data files."
22
+
23
+ return "\n".join(files)
24
+
25
+ def run_training_subprocess(hf_token, model_size, hub_username, num_epochs, use_balanced):
26
+ """Run training in subprocess"""
27
+ global training_status
28
+
29
+ try:
30
+ # Determine which data file to use
31
+ if use_balanced and os.path.exists("combined_balanced_training_data.csv"):
32
+ data_path = "combined_balanced_training_data.csv"
33
+ elif os.path.exists("combined_final_training_data.csv"):
34
+ data_path = "combined_final_training_data.csv"
35
+ else:
36
+ training_status["status"] = "error"
37
+ training_status["message"] = "No training data found!"
38
+ return
39
+
40
+ # Determine model size
41
+ size = "base" if "Base" in model_size else "large"
42
+
43
+ # Build command
44
+ cmd = [
45
+ "python", "train.py",
46
+ "--data_path", data_path,
47
+ "--model_size", size,
48
+ "--num_epochs", str(num_epochs),
49
+ "--use_lora" # Always use LoRA for efficiency
50
+ ]
51
+
52
+ if hf_token:
53
+ cmd.extend(["--hf_token", hf_token])
54
+ if hub_username:
55
+ cmd.extend(["--hub_username", hub_username])
56
+
57
+ training_status["status"] = "running"
58
+ training_status["message"] = "Starting training..."
59
+ training_status["progress"] = 0
60
+
61
+ # Run training
62
+ process = subprocess.Popen(
63
+ cmd,
64
+ stdout=subprocess.PIPE,
65
+ stderr=subprocess.STDOUT,
66
+ text=True,
67
+ bufsize=1
68
+ )
69
+
70
+ # Read output line by line
71
+ for line in process.stdout:
72
+ if "loss" in line.lower():
73
+ training_status["message"] = line.strip()
74
+ elif "epoch" in line.lower():
75
+ # Try to extract progress
76
+ try:
77
+ if "/" in line:
78
+ parts = line.split("/")
79
+ current = float(parts[0].split()[-1])
80
+ total = float(parts[1].split()[0])
81
+ training_status["progress"] = int((current / total) * 100)
82
+ except:
83
+ pass
84
+ elif "exact_match" in line.lower():
85
+ training_status["message"] = f"Evaluation: {line.strip()}"
86
+
87
+ process.wait()
88
+
89
+ if process.returncode == 0:
90
+ training_status["status"] = "completed"
91
+ training_status["message"] = "Training completed successfully! Model pushed to HuggingFace Hub."
92
+ training_status["progress"] = 100
93
+ else:
94
+ training_status["status"] = "error"
95
+ training_status["message"] = f"Training failed with exit code {process.returncode}"
96
+
97
+ except Exception as e:
98
+ training_status["status"] = "error"
99
+ training_status["message"] = f"Error: {str(e)}"
100
+
101
+ def train_model(hf_token, model_size, hub_username, num_epochs, use_balanced):
102
+ """Start training in background thread"""
103
+ global training_status
104
+
105
+ if not hf_token:
106
+ return "❌ Please provide HuggingFace token"
107
+
108
+ if training_status["status"] == "running":
109
+ return "⚠️ Training already in progress!"
110
+
111
+ # Start training in background thread
112
+ thread = threading.Thread(
113
+ target=run_training_subprocess,
114
+ args=(hf_token, model_size, hub_username, num_epochs, use_balanced)
115
+ )
116
+ thread.start()
117
+
118
+ return "🚀 Training started! Check status below..."
119
+
120
+ def get_training_status():
121
+ """Get current training status"""
122
+ global training_status
123
+
124
+ if training_status["status"] == "idle":
125
+ return "💤 No training in progress"
126
+ elif training_status["status"] == "running":
127
+ return f"""
128
+ 🏃 Training in progress... ({training_status['progress']}%)
129
+
130
+ {training_status['message']}
131
+ """
132
+ elif training_status["status"] == "completed":
133
+ return f"""
134
+ ✅ Training completed!
135
+
136
+ {training_status['message']}
137
+
138
+ Your model is available at: https://huggingface.co/{training_status.get('hub_username', 'your-username')}/resume-normalizer-flan-t5
139
+ """
140
+ else:
141
+ return f"""
142
+ ❌ Training failed!
143
+
144
+ {training_status['message']}
145
+ """
146
+
147
+ # Create Gradio interface
148
+ with gr.Blocks(title="Resume Normalizer Trainer") as app:
149
+ gr.Markdown("""
150
+ # Resume Normalizer Trainer
151
+
152
+ Train a Flan-T5 model to normalize company names, job titles, and skills from resumes.
153
+
154
+ **Features:**
155
+ - Company name normalization (e.g., "Google Inc" → "Alphabet Inc.")
156
+ - Job title standardization (e.g., "SWE" → "Software Engineer")
157
+ - Skills normalization (e.g., "JS" → "JavaScript")
158
+ - Binary equivalency detection
159
+
160
+ **Hardware:** Running on 4xL4 GPUs (96GB VRAM)
161
+ """)
162
+
163
+ with gr.Tab("📊 Check Data"):
164
+ check_btn = gr.Button("Check Available Datasets", variant="primary")
165
+ check_output = gr.Textbox(label="Dataset Status", lines=5)
166
+ check_btn.click(check_data, outputs=check_output)
167
+
168
+ with gr.Tab("🚀 Train Model"):
169
+ with gr.Row():
170
+ with gr.Column():
171
+ hf_token = gr.Textbox(
172
+ label="HuggingFace Token",
173
+ type="password",
174
+ placeholder="hf_...",
175
+ info="Required to push model to Hub"
176
+ )
177
+ hub_username = gr.Textbox(
178
+ label="HuggingFace Username",
179
+ value="aoisfhdugbos",
180
+ info="Your HuggingFace username"
181
+ )
182
+
183
+ with gr.Column():
184
+ model_size = gr.Dropdown(
185
+ label="Model Size",
186
+ choices=["T5-Base (250M)", "T5-Large (770M)"],
187
+ value="T5-Base (250M)",
188
+ info="Larger models are more accurate but slower"
189
+ )
190
+ num_epochs = gr.Slider(
191
+ label="Training Epochs",
192
+ minimum=1,
193
+ maximum=10,
194
+ value=5,
195
+ step=1,
196
+ info="More epochs = better quality but longer training"
197
+ )
198
+
199
+ use_balanced = gr.Checkbox(
200
+ label="Use Balanced Dataset (8,304 examples)",
201
+ value=False,
202
+ info="Check to use balanced dataset instead of full dataset (9,302 examples)"
203
+ )
204
+
205
+ train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
206
+ train_output = gr.Textbox(label="Training Output", lines=5)
207
+
208
+ train_btn.click(
209
+ train_model,
210
+ inputs=[hf_token, model_size, hub_username, num_epochs, use_balanced],
211
+ outputs=train_output
212
+ )
213
+
214
+ with gr.Tab("📈 Training Status"):
215
+ status_btn = gr.Button("🔄 Refresh Status", variant="secondary")
216
+ status_output = gr.Textbox(label="Current Status", lines=10)
217
+
218
+ # Auto-refresh status
219
+ status_timer = gr.Timer(5) # Refresh every 5 seconds
220
+ status_timer.tick(get_training_status, outputs=status_output)
221
+
222
+ status_btn.click(get_training_status, outputs=status_output)
223
+
224
+ with gr.Tab("ℹ️ About"):
225
+ gr.Markdown("""
226
+ ## Resume Normalizer Model
227
+
228
+ This trainer fine-tunes a Flan-T5 model for resume entity normalization tasks:
229
+
230
+ ### Supported Tasks:
231
+ 1. **Company Normalization**: Handles mergers, acquisitions, rebranding
232
+ 2. **Job Title Standardization**: Recognizes equivalent roles and seniority
233
+ 3. **Skills Normalization**: Standardizes technology names and abbreviations
234
+ 4. **Equivalency Detection**: Binary classification for entity matching
235
+
236
+ ### Model Architecture:
237
+ - Base Model: Google Flan-T5 (instruction-tuned)
238
+ - Fine-tuning: LoRA (Low-Rank Adaptation) for efficiency
239
+ - Multi-task: Uses task prefixes ([COMPANY], [JOB], [SKILLS])
240
+
241
+ ### Training Data:
242
+ - 478 manually curated high-quality examples
243
+ - 8,824 synthetic examples generated with GPT-4
244
+ - Total: 9,302 training examples
245
+
246
+ ### Expected Performance:
247
+ - Inference: <100ms per query
248
+ - Accuracy: >90% on test set
249
+ - Model size: 250M-770M parameters
250
+ """)
251
+
252
+ app.launch()
combined_balanced_training_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
combined_final_training_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers>=4.36.0
2
+ datasets>=2.16.0
3
+ accelerate>=0.25.0
4
+ peft>=0.7.0
5
+ evaluate>=0.4.1
6
+ rouge_score>=0.1.2
7
+ pandas>=2.0.0
8
+ numpy>=1.24.0
9
+ torch>=2.1.0
10
+ sentencepiece>=0.1.99
11
+ huggingface_hub>=0.20.0
12
+ gradio==5.6.0
13
+ tensorboard>=2.15.0
14
+ scikit-learn>=1.3.0
train.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import (
3
+ T5ForConditionalGeneration,
4
+ T5Tokenizer,
5
+ TrainingArguments,
6
+ Trainer,
7
+ DataCollatorForSeq2Seq,
8
+ EarlyStoppingCallback
9
+ )
10
+ from datasets import Dataset
11
+ import pandas as pd
12
+ import numpy as np
13
+ from accelerate import Accelerator
14
+ import os
15
+ from huggingface_hub import HfFolder
16
+ import logging
17
+
18
+ # Setup logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ class ResumeNormalizationTrainer:
23
+ def __init__(self, model_name="google/flan-t5-base", use_lora=True):
24
+ self.model_name = model_name
25
+ self.use_lora = use_lora
26
+ self.accelerator = Accelerator()
27
+ self.device = self.accelerator.device
28
+
29
+ logger.info(f"Initializing model: {model_name}")
30
+ logger.info(f"Using LoRA: {use_lora}")
31
+ logger.info(f"Device: {self.device}")
32
+
33
+ # Load tokenizer and model
34
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name)
35
+ self.model = T5ForConditionalGeneration.from_pretrained(model_name)
36
+
37
+ # Setup LoRA if requested
38
+ if use_lora:
39
+ self._setup_lora()
40
+
41
+ def _setup_lora(self):
42
+ """Configure LoRA for efficient fine-tuning"""
43
+ try:
44
+ from peft import LoraConfig, get_peft_model, TaskType
45
+
46
+ lora_config = LoraConfig(
47
+ r=16, # rank
48
+ lora_alpha=32,
49
+ target_modules=["q", "v"], # T5 attention layers
50
+ lora_dropout=0.1,
51
+ bias="none",
52
+ task_type=TaskType.SEQ_2_SEQ_LM,
53
+ )
54
+
55
+ self.model = get_peft_model(self.model, lora_config)
56
+ self.model.print_trainable_parameters()
57
+ logger.info("LoRA configuration applied successfully")
58
+ except Exception as e:
59
+ logger.error(f"Failed to setup LoRA: {e}")
60
+ raise
61
+
62
+ def load_dataset(self, data_path):
63
+ """Load and prepare dataset"""
64
+ logger.info(f"Loading dataset from: {data_path}")
65
+ df = pd.read_csv(data_path)
66
+ logger.info(f"Loaded {len(df)} examples")
67
+
68
+ # Add task prefixes if not present
69
+ def add_task_prefix(row):
70
+ task = row['task_type']
71
+ instruction = row['instruction']
72
+
73
+ # Skip if already has prefix
74
+ if instruction.startswith('['):
75
+ return instruction
76
+
77
+ if task == 'normalize_company':
78
+ return f"[COMPANY] {instruction}"
79
+ elif task == 'normalize_job_title':
80
+ return f"[JOB] {instruction}"
81
+ elif task == 'normalize_skill':
82
+ return f"[SKILLS] {instruction}"
83
+ elif task == 'company_equivalence':
84
+ return f"[COMPANY] {instruction}"
85
+ elif task == 'job_title_equivalence':
86
+ return f"[JOB] {instruction}"
87
+ elif task == 'achievement_equivalence':
88
+ return f"[ACHIEVEMENT] {instruction}"
89
+ return instruction
90
+
91
+ df['instruction'] = df.apply(add_task_prefix, axis=1)
92
+
93
+ # Split into train/validation
94
+ train_size = int(0.9 * len(df))
95
+ train_df = df[:train_size]
96
+ val_df = df[train_size:]
97
+
98
+ logger.info(f"Train set: {len(train_df)} examples")
99
+ logger.info(f"Validation set: {len(val_df)} examples")
100
+
101
+ # Convert to HuggingFace Dataset
102
+ train_dataset = Dataset.from_pandas(train_df)
103
+ val_dataset = Dataset.from_pandas(val_df)
104
+
105
+ return train_dataset, val_dataset
106
+
107
+ def preprocess_function(self, examples):
108
+ """Tokenize inputs and targets"""
109
+ inputs = examples['instruction']
110
+ targets = examples['output']
111
+
112
+ # Tokenize inputs
113
+ model_inputs = self.tokenizer(
114
+ inputs,
115
+ max_length=256,
116
+ truncation=True,
117
+ padding="max_length"
118
+ )
119
+
120
+ # Tokenize targets
121
+ with self.tokenizer.as_target_tokenizer():
122
+ labels = self.tokenizer(
123
+ targets,
124
+ max_length=128,
125
+ truncation=True,
126
+ padding="max_length"
127
+ )
128
+
129
+ # Replace padding token id's of the labels by -100
130
+ labels["input_ids"] = [
131
+ [(l if l != self.tokenizer.pad_token_id else -100) for l in label]
132
+ for label in labels["input_ids"]
133
+ ]
134
+
135
+ model_inputs["labels"] = labels["input_ids"]
136
+ return model_inputs
137
+
138
+ def compute_metrics(self, eval_pred):
139
+ """Compute metrics for evaluation"""
140
+ predictions, labels = eval_pred
141
+
142
+ # Decode predictions
143
+ decoded_preds = self.tokenizer.batch_decode(
144
+ predictions, skip_special_tokens=True
145
+ )
146
+
147
+ # Replace -100 in the labels as we can't decode them
148
+ labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)
149
+ decoded_labels = self.tokenizer.batch_decode(
150
+ labels, skip_special_tokens=True
151
+ )
152
+
153
+ # Calculate exact match accuracy
154
+ exact_match = sum(
155
+ pred.strip().lower() == label.strip().lower()
156
+ for pred, label in zip(decoded_preds, decoded_labels)
157
+ ) / len(decoded_preds)
158
+
159
+ logger.info(f"Exact match accuracy: {exact_match:.4f}")
160
+
161
+ return {"exact_match": exact_match}
162
+
163
+ def train(self, train_dataset, val_dataset, output_dir, hf_token=None, hub_username=None, num_epochs=5):
164
+ """Train the model"""
165
+ logger.info("Starting training preparation...")
166
+
167
+ # Tokenize datasets
168
+ train_dataset = train_dataset.map(
169
+ self.preprocess_function,
170
+ batched=True,
171
+ remove_columns=['instruction', 'output', 'task_type', 'quality_score']
172
+ )
173
+
174
+ val_dataset = val_dataset.map(
175
+ self.preprocess_function,
176
+ batched=True,
177
+ remove_columns=['instruction', 'output', 'task_type', 'quality_score']
178
+ )
179
+
180
+ # Data collator
181
+ data_collator = DataCollatorForSeq2Seq(
182
+ self.tokenizer,
183
+ model=self.model,
184
+ label_pad_token_id=-100,
185
+ pad_to_multiple_of=8
186
+ )
187
+
188
+ # Training arguments optimized for 4xL4 GPUs
189
+ training_args = TrainingArguments(
190
+ output_dir=output_dir,
191
+ num_train_epochs=num_epochs,
192
+ per_device_train_batch_size=32, # L4 has 24GB, can handle larger batches
193
+ per_device_eval_batch_size=64,
194
+ gradient_accumulation_steps=1,
195
+ gradient_checkpointing=True,
196
+ fp16=True, # Use mixed precision
197
+ optim="adamw_torch",
198
+ learning_rate=3e-4 if self.use_lora else 5e-5,
199
+ warmup_steps=500,
200
+ logging_steps=50,
201
+ evaluation_strategy="steps",
202
+ eval_steps=500,
203
+ save_strategy="steps",
204
+ save_steps=500,
205
+ load_best_model_at_end=True,
206
+ metric_for_best_model="exact_match",
207
+ greater_is_better=True,
208
+ push_to_hub=True if hf_token else False,
209
+ hub_model_id=f"{hub_username}/resume-normalizer-flan-t5" if hub_username else None,
210
+ hub_token=hf_token,
211
+ report_to=["tensorboard"],
212
+ ddp_find_unused_parameters=False,
213
+ dataloader_num_workers=4,
214
+ remove_unused_columns=False,
215
+ )
216
+
217
+ # Create trainer
218
+ trainer = Trainer(
219
+ model=self.model,
220
+ args=training_args,
221
+ train_dataset=train_dataset,
222
+ eval_dataset=val_dataset,
223
+ tokenizer=self.tokenizer,
224
+ data_collator=data_collator,
225
+ compute_metrics=self.compute_metrics,
226
+ callbacks=[
227
+ EarlyStoppingCallback(early_stopping_patience=3)
228
+ ],
229
+ )
230
+
231
+ logger.info("Starting training...")
232
+
233
+ # Train
234
+ trainer.train()
235
+
236
+ # Save model
237
+ logger.info("Saving model...")
238
+ if self.use_lora:
239
+ # Save LoRA adapter
240
+ self.model.save_pretrained(output_dir)
241
+ self.tokenizer.save_pretrained(output_dir)
242
+ else:
243
+ trainer.save_model(output_dir)
244
+
245
+ # Push to hub if token provided
246
+ if hf_token and hub_username:
247
+ logger.info("Pushing model to HuggingFace Hub...")
248
+ trainer.push_to_hub(
249
+ commit_message="Final model trained on resume normalization data"
250
+ )
251
+
252
+ logger.info("Training completed successfully!")
253
+ return trainer
254
+
255
+ def main():
256
+ """Main training function to be called from app.py"""
257
+ import argparse
258
+ parser = argparse.ArgumentParser()
259
+ parser.add_argument("--data_path", type=str, required=True)
260
+ parser.add_argument("--model_size", type=str, default="base")
261
+ parser.add_argument("--hf_token", type=str, default=None)
262
+ parser.add_argument("--hub_username", type=str, default=None)
263
+ parser.add_argument("--num_epochs", type=int, default=5)
264
+ parser.add_argument("--use_lora", action="store_true")
265
+ args = parser.parse_args()
266
+
267
+ # Set HF token if provided
268
+ if args.hf_token:
269
+ HfFolder.save_token(args.hf_token)
270
+
271
+ # Select model based on size
272
+ model_name = "google/flan-t5-base" if args.model_size == "base" else "google/flan-t5-large"
273
+
274
+ # Initialize trainer
275
+ trainer = ResumeNormalizationTrainer(
276
+ model_name=model_name,
277
+ use_lora=args.use_lora
278
+ )
279
+
280
+ # Load dataset
281
+ train_dataset, val_dataset = trainer.load_dataset(args.data_path)
282
+
283
+ # Train
284
+ output_dir = "./resume-normalizer-model"
285
+ trainer.train(
286
+ train_dataset=train_dataset,
287
+ val_dataset=val_dataset,
288
+ output_dir=output_dir,
289
+ hf_token=args.hf_token,
290
+ hub_username=args.hub_username,
291
+ num_epochs=args.num_epochs
292
+ )
293
+
294
+ print("Training completed successfully!")
295
+ print(f"Model saved to: {output_dir}")
296
+ if args.hf_token and args.hub_username:
297
+ print(f"Model available at: https://huggingface.co/{args.hub_username}/resume-normalizer-flan-t5")
298
+
299
+ if __name__ == "__main__":
300
+ main()