This model is a LoRA fine-tuned version of the facebook/bart-large-cnn model, trained on the ccdv/arxiv-summarization dataset.
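LoRA (low-rank adaptation) keeps a pretrained weight matrix W frozen and trains only two small matrices, B of shape d×r and A of shape r×k, so the adapted layer uses W + (alpha / r)·B·A. The plain-Python sketch below illustrates that update. The matrix values, the rank r = 1 and alpha = 2.0 are hypothetical and are not the hyperparameters used to train this adapter.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    inner = len(b)
    assert all(len(row) == inner for row in a), "shape mismatch"
    return [
        [sum(a[i][t] * b[t][j] for t in range(inner)) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def lora_effective_weight(w, b, a, alpha):
    """Return W + (alpha / r) * (B @ A); r is the LoRA rank (number of rows of A)."""
    r = len(a)
    scale = alpha / r
    delta = matmul(b, a)
    return [
        [w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
        for i in range(len(w))
    ]


def lora_trainable_params(d, k, r):
    """Parameters trained by a rank-r adapter on a d x k matrix: r * (d + k)."""
    return r * (d + k)


# Hypothetical frozen 3x3 weight and a rank-1 adapter (B is 3x1, A is 1x3)
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0],
     [0.0, 0.0, 1.0]]
B = [[1.0], [0.0], [2.0]]
A = [[0.5, 0.5, 0.0]]

W_eff = lora_effective_weight(W, B, A, alpha=2.0)
print(W_eff)  # [[2.0, 1.0, 0.0], [0.0, 1.0, 0.0], [2.0, 2.0, 1.0]]

# A 1024 x 1024 projection (BART-large's hidden size) with a hypothetical rank of 8
print(lora_trainable_params(1024, 1024, 8), "vs", 1024 * 1024)  # 16384 vs 1048576
```

Because B·A has rank at most r, the adapter holds far fewer parameters than the full matrix, which is why an adapter like this one is loaded on top of a separately downloaded base model.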
The fine-tuning procedure can be viewed here. I also included a separate notebook showcasing how the fine-tuned model performs compared with the base model on sample article texts from arXiv.
Use the code below to get started with the model.
```python
from peft import AutoPeftModelForSeq2SeqLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
# Loads the base model recorded in the adapter config and attaches the LoRA adapter
lora_model = AutoPeftModelForSeq2SeqLM.from_pretrained("spolivin/bart-arxiv-lora")
```
Alternatively, load the base model explicitly and attach the adapter to it:
```python
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
# Load the base model first, then wrap it with the LoRA adapter weights
base_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
lora_model = PeftModel.from_pretrained(base_model, "spolivin/bart-arxiv-lora")
```
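Both loading options should give the same model. The `AutoPeftModelForSeq2SeqLM` route does not need the base checkpoint spelled out because a PEFT adapter repository stores an `adapter_config.json` whose `base_model_name_or_path` field names it, and the auto class reads that field to load the base model before attaching the adapter. The standard-library sketch below mimics only that lookup step: `resolve_base_model` is a hypothetical helper, and the config contents are illustrative rather than copied from spolivin/bart-arxiv-lora.

```python
import json


def resolve_base_model(adapter_config_text: str) -> str:
    """Return the base checkpoint recorded in a PEFT adapter_config.json string."""
    config = json.loads(adapter_config_text)
    base = config.get("base_model_name_or_path")
    if not base:
        raise ValueError("adapter config does not name a base model")
    return base


# Illustrative adapter_config.json content: the field names follow PEFT's
# LoraConfig, but the r and lora_alpha values here are hypothetical.
example_config = json.dumps({
    "peft_type": "LORA",
    "task_type": "SEQ_2_SEQ_LM",
    "base_model_name_or_path": "facebook/bart-large-cnn",
    "r": 8,
    "lora_alpha": 16,
})

print(resolve_base_model(example_config))  # facebook/bart-large-cnn
```

With the `PeftModel` route you choose the base model yourself, so it should be the same checkpoint the adapter was trained on (facebook/bart-large-cnn here).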
After loading the model with adapters, one can easily use it for summarization tasks:
```python
import torch

text = "Some sample article texts as a string"

# Use a GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizing text (BART accepts at most 1024 input tokens, so longer articles are truncated)
inputs = tokenizer(
    text,
    return_tensors="pt",
    truncation=True,
    max_length=1024,
).to(device)
lora_model.to(device)

# Generating summarized version with beam search
summary_ids = lora_model.generate(**inputs, max_length=250, num_beams=4, early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print(summary)
```
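The `generate` call uses beam search: with `num_beams=4` it keeps the four highest-scoring partial summaries at each step, and `early_stopping=True` ends the search as soon as `num_beams` finished candidates exist. The toy sketch below illustrates that idea over a hypothetical, hard-coded next-token table. It is not the Transformers implementation, which scores tokens with the model and also normalizes finished scores by length (controlled by `length_penalty`).

```python
import math

EOS = "</s>"

# Hypothetical next-token probabilities, keyed by the previous token only
NEXT = {
    "<s>": {"we": 0.6, "this": 0.4},
    "we": {"propose": 0.4, "show": 0.35, EOS: 0.25},
    "this": {"paper": 0.9, EOS: 0.1},
    "propose": {EOS: 1.0},
    "show": {EOS: 1.0},
    "paper": {EOS: 1.0},
}


def beam_search(num_beams=4, max_len=10):
    """Toy beam search with early stopping; scores are summed log-probabilities."""
    beams = [(0.0, ["<s>"])]  # (score, tokens) for unfinished hypotheses
    finished = []
    for _ in range(max_len):
        # Expand every live beam by every possible next token
        candidates = [
            (score + math.log(p), seq + [tok])
            for score, seq in beams
            for tok, p in NEXT[seq[-1]].items()
        ]
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates:
            if seq[-1] == EOS:
                finished.append((score, seq))  # completed hypothesis
            else:
                beams.append((score, seq))
            if len(beams) == num_beams:  # keep only the top num_beams live beams
                break
        # early_stopping=True: stop once num_beams hypotheses are complete
        if len(finished) >= num_beams or not beams:
            break
    _, best_seq = max(finished + beams, key=lambda c: c[0])
    return [tok for tok in best_seq if tok not in ("<s>", EOS)]


print(beam_search(num_beams=1))  # greedy: ['we', 'propose']
print(beam_search(num_beams=2))  # ['this', 'paper']
```

With `num_beams=1` the search reduces to greedy decoding and commits to the locally likely first token "we", while two beams keep "this" alive long enough to find the higher-probability sequence "this paper" (0.4 × 0.9 = 0.36 versus 0.6 × 0.4 = 0.24).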
Base model: facebook/bart-large-cnn