# Hugging Face Spaces page residue (status banner), commented out so the file parses:
# Spaces: Sleeping
import collections
import os
import re
import string

import numpy as np
from datasets import load_dataset, load_metric
from huggingface_hub import login
from peft import LoraConfig, get_peft_model
from transformers import (
    DebertaTokenizerFast,
    DebertaForQuestionAnswering,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
# Authenticate with the Hugging Face Hub when a token is available.  The
# token is read from the 'roberta_token' environment variable so it never
# has to live in the source tree; without it, push_to_hub will likely fail.
hf_token = os.environ.get("roberta_token")
if not hf_token:
    print("Warning: HF token not found in environment variable 'roberta_token'. Push to hub may fail.")
else:
    login(token=hf_token)

# SQuAD metric (exact match + F1), consumed by compute_metrics below.
metric = load_metric("squad")
def normalize_answer(s):
    """Normalize an answer string for SQuAD-style comparison.

    Applies, in order: lower-casing, punctuation removal, removal of the
    articles 'a'/'an'/'the', and collapsing of whitespace runs into single
    spaces.
    """
    lowered = s.lower()
    punctuation = set(string.punctuation)
    no_punct = ''.join(ch for ch in lowered if ch not in punctuation)
    no_articles = re.sub(r'\b(a|an|the)\b', ' ', no_punct)
    return ' '.join(no_articles.split())
def prepare_train_features(examples, tokenizer, max_length=512, doc_stride=128):
    """Tokenize question/context batches and attach start/end token labels.

    Long contexts are split into overlapping windows (``doc_stride`` tokens of
    overlap), so one example may yield several features.  For each feature the
    reference answer's character span is mapped to token indices; features
    whose window does not contain the answer are labeled with the CLS index.

    Args:
        examples: batch dict with "question", "context" and "answers" columns
            (SQuAD-style answers: lists under "answer_start" and "text").
        tokenizer: a fast tokenizer (offset mappings/sequence_ids required).
        max_length: maximum feature length in tokens.
        doc_stride: overlap between consecutive context windows.

    Returns:
        The tokenized batch with added "start_positions"/"end_positions".
    """
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",  # only the context gets truncated/windowed
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # Maps each overflow feature back to the example it came from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")
    start_positions = []
    end_positions = []
    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        # CLS serves as the "no answer in this window" label position.
        cls_index = input_ids.index(tokenizer.cls_token_id)
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        if len(answers["answer_start"]) == 0:
            # Unanswerable example: point both labels at CLS.
            start_positions.append(cls_index)
            end_positions.append(cls_index)
        else:
            # Character span of the first (reference) answer.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])
            sequence_ids = tokenized_examples.sequence_ids(i)
            # Locate the first and last tokens belonging to the context
            # (sequence id 1; question tokens are sequence id 0).
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                # Answer lies outside this window: label with CLS.
                start_positions.append(cls_index)
                end_positions.append(cls_index)
            else:
                # Advance past the answer's start, then step back one token.
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                start_positions.append(token_start_index - 1)
                # Walk back from the window end to the answer's last token.
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                end_positions.append(token_end_index + 1)
    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size=20, max_answer_length=30, cls_token_id=None):
    """Convert start/end logits into answer-text predictions per example.

    Args:
        examples: original (un-tokenized) examples; must expose an "id"
            column and per-example dicts with "id" and "context".
        features: tokenized features, each carrying "example_id",
            "input_ids" and "offset_mapping" (non-context offsets ``None``).
        raw_predictions: ``(start_logits, end_logits)`` arrays of shape
            ``(num_features, seq_len)``.
        n_best_size: number of top start/end candidates to consider.
        max_answer_length: maximum answer span length in tokens.
        cls_token_id: optional id of the CLS token.  When given, the CLS
            position is looked up in each feature's input_ids; otherwise the
            CLS token is assumed to sit at position 0 (true for BERT/DeBERTa
            style encodings).

    Returns:
        ``collections.OrderedDict`` mapping example id -> best answer text.
    """
    all_start_logits, all_end_logits = raw_predictions
    # Group feature indices per example: long contexts were split into
    # several overlapping features at tokenization time.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)
    predictions = collections.OrderedDict()
    for example_index, example in enumerate(examples):
        feature_indices = features_per_example[example_index]
        min_null_score = None  # tracked for completeness; unused for SQuAD v1 output
        valid_answers = []
        context = example["context"]
        for feature_index in feature_indices:
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]
            # BUGFIX: the original read features[...]["cls_token_id"], a key
            # the features never contain, which raised KeyError on first use.
            if cls_token_id is not None:
                cls_index = features[feature_index]["input_ids"].index(cls_token_id)
            else:
                cls_index = 0
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score > feature_null_score:
                min_null_score = feature_null_score
            # Top-n_best start/end candidate indices, highest logit first.
            start_indexes = np.argsort(start_logits)[-1: -n_best_size - 1: -1].tolist()
            end_indexes = np.argsort(end_logits)[-1: -n_best_size - 1: -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip candidates outside the mapped context tokens.
                    if (
                        start_index >= len(offsets)
                        or end_index >= len(offsets)
                        or offsets[start_index] is None
                        or offsets[end_index] is None
                    ):
                        continue
                    # Skip inverted or overlong spans.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    start_char = offsets[start_index][0]
                    end_char = offsets[end_index][1]
                    valid_answers.append(
                        {"score": start_logits[start_index] + end_logits[end_index], "text": context[start_char:end_char]}
                    )
        best_answer = max(valid_answers, key=lambda x: x["score"]) if valid_answers else {"text": "", "score": 0.0}
        predictions[example["id"]] = best_answer["text"]
    return predictions
def compute_metrics(p, tokenizer, examples, features):
    """Compute SQuAD exact-match/F1 for a Trainer ``EvalPrediction``.

    ``tokenizer`` is accepted for interface compatibility but unused here;
    decoding relies on the character offsets inside ``features``.
    """
    text_predictions = postprocess_qa_predictions(examples, features, p.predictions)
    references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    formatted_predictions = [
        {"id": example_id, "prediction_text": text}
        for example_id, text in text_predictions.items()
    ]
    return metric.compute(predictions=formatted_predictions, references=references)
def main():
    """Fine-tune DeBERTa-xlarge on CUAD QA with LoRA adapters and push to the Hub."""
    model_name = "microsoft/deberta-xlarge"
    output_dir = "./deberta-lora-cuad-finetuned"
    datasets = load_dataset("theatticusproject/cuad-qa")
    tokenizer = DebertaTokenizerFast.from_pretrained(model_name)
    model = DebertaForQuestionAnswering.from_pretrained(model_name)
    # LoRA config: tune rank and dropout as needed
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["query", "value"],  # Adjust for DeBERTa internals as needed
        lora_dropout=0.1,
        bias="none",
        task_type="QUESTION_ANSWERING",
    )
    model = get_peft_model(model, lora_config)

    def prepare_validation_features(examples):
        # BUGFIX: validation features previously came straight from
        # prepare_train_features, which pops "offset_mapping" and never adds
        # "example_id" -- both are required by postprocess_qa_predictions, so
        # evaluation crashed.  Re-tokenize with identical settings to recover
        # them and attach them to the training-style features.
        features = prepare_train_features(examples, tokenizer)
        tokenized = tokenizer(
            examples["question"],
            examples["context"],
            truncation="only_second",
            max_length=512,
            stride=128,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        sample_mapping = tokenized.pop("overflow_to_sample_mapping")
        features["example_id"] = [examples["id"][idx] for idx in sample_mapping]
        # Null out offsets of non-context tokens so post-processing skips them.
        cleaned_offsets = []
        for i, offsets in enumerate(tokenized["offset_mapping"]):
            seq_ids = tokenized.sequence_ids(i)
            cleaned_offsets.append(
                [offset if seq_ids[k] == 1 else None for k, offset in enumerate(offsets)]
            )
        features["offset_mapping"] = cleaned_offsets
        return features

    train_dataset = datasets["train"].map(
        lambda examples: prepare_train_features(examples, tokenizer),
        batched=True,
        remove_columns=datasets["train"].column_names,
    )
    val_dataset = datasets["validation"].map(
        prepare_validation_features,
        batched=True,
        remove_columns=datasets["validation"].column_names,
    )
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        eval_steps=500,
        save_steps=500,
        save_total_limit=2,
        learning_rate=3e-4,  # LoRA usually supports higher LR
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=3,
        weight_decay=0.0,
        logging_dir=f"{output_dir}/logs",
        logging_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        fp16=True,
        push_to_hub=True,
        hub_model_id="AvocadoMuffin/deberta_finetuned_qa_lora",
        hub_strategy="checkpoint",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        # NOTE(review): the Trainer drops columns the model forward does not
        # accept (example_id/offset_mapping) from eval batches, but the
        # closure below captures the full val_dataset, so post-processing
        # still sees them -- confirm against the installed transformers
        # version's remove_unused_columns behavior.
        compute_metrics=lambda p: compute_metrics(p, tokenizer, datasets["validation"], val_dataset),
    )
    trainer.train()
    trainer.push_to_hub()


if __name__ == "__main__":
    main()