# IR-Electra
Fine-tunes a pre-trained ELECTRA model for document relevance classification in Information Retrieval (IR).
## Overview
This project fine-tunes ELECTRA on a synthetic dataset of query-document pairs built for relevance classification. The dataset covers three topics (sports, tech, and health), and each pair is labeled 1 if the document is relevant to the query and 0 otherwise.
## Steps
1. **Generate Synthetic Dataset**
   Create query-document pairs for three categories (sports, tech, health), with labels indicating relevance (1 for relevant, 0 for non-relevant).
2. **Preprocess Dataset**
   Use a custom dataset class to tokenize the query-document pairs and format them for model input.
3. **Fine-Tune ELECTRA**
   Fine-tune the ELECTRA model on the dataset to classify documents as relevant or non-relevant.
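
The preprocessing step could look roughly like the sketch below. This is not the project's actual dataset class: the name `RelevanceDataset`, the `(query, document, label)` triple layout, and the returned keys are assumptions made for illustration. The `max_length=256` and `padding="max_length"` settings mirror the inference code in the Usage section. The `tokenizer` argument is expected to be a Hugging Face tokenizer (or any callable with the same keyword arguments).

```python
class RelevanceDataset:
    """Hypothetical map-style dataset over (query, document, label) triples."""

    def __init__(self, pairs, tokenizer, max_length=256):
        self.pairs = list(pairs)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        query, document, label = self.pairs[idx]
        # Encode query and document as one sequence pair,
        # padded or truncated to a fixed length.
        encoding = self.tokenizer(
            query,
            document,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
        )
        return {
            "input_ids": encoding["input_ids"],
            "attention_mask": encoding["attention_mask"],
            "label": int(label),  # 1 = relevant, 0 = non-relevant
        }
```

Each item holds plain Python lists rather than tensors, so a collator such as `transformers.default_data_collator` is assumed to handle batching and tensor conversion during training.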
## Topics
```python
topics = {
    "sports": ["football", "basketball", "tennis", "match", "score"],
    "tech": ["AI", "machine learning", "cloud", "software", "algorithm"],
    "health": ["diet", "exercise", "nutrition", "wellness", "fitness"]
}
```
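
One way such keyword lists could be turned into labeled query-document pairs is sketched below. The README does not describe the actual generation procedure, so the function name `make_pairs`, the sentence template, the two-keyword queries, and the alternating labels are all assumptions made for illustration. Only the `topics` dictionary comes from this project.

```python
import random

# Topic keywords copied from the Topics section.
topics = {
    "sports": ["football", "basketball", "tennis", "match", "score"],
    "tech": ["AI", "machine learning", "cloud", "software", "algorithm"],
    "health": ["diet", "exercise", "nutrition", "wellness", "fitness"],
}


def make_pairs(topics, n_per_topic=4, seed=0):
    """Return (query, document, label) triples; the templates are hypothetical."""
    rng = random.Random(seed)
    names = list(topics)
    pairs = []
    for topic in names:
        for i in range(n_per_topic):
            query = " ".join(rng.sample(topics[topic], 2))
            # Alternate labels so both classes are equally represented.
            label = 1 if i % 2 == 0 else 0
            # Relevant documents reuse the query's topic;
            # non-relevant documents draw from a different topic.
            if label == 1:
                doc_topic = topic
            else:
                doc_topic = rng.choice([t for t in names if t != topic])
            first, second = rng.sample(topics[doc_topic], 2)
            document = f"This article covers {first} and {second}."
            pairs.append((query, document, label))
    return pairs


pairs = make_pairs(topics)
```

Seeding the random generator keeps the dataset reproducible across runs, which makes fine-tuning results easier to compare.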
## Usage
```python
import torch
from transformers import ElectraForSequenceClassification, ElectraTokenizer


def predict_relevance(query, document, model, tokenizer, device=None):
    """Classify a query-document pair as "Relevant" or "Non-relevant"."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)

    # Encode the query and document together as a single sequence pair.
    encoding = tokenizer(
        query,
        document,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Run inference without tracking gradients.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=1).item()

    # Class 1 means relevant, class 0 means non-relevant.
    return "Relevant" if predicted_class == 1 else "Non-relevant"


if __name__ == "__main__":
    MODEL_HUB_ID = "your-hf-username/your-model-repo-name"  # <-- replace this

    print("Loading model and tokenizer from Hugging Face Hub...")
    model = ElectraForSequenceClassification.from_pretrained(MODEL_HUB_ID)
    tokenizer = ElectraTokenizer.from_pretrained(MODEL_HUB_ID)

    query = "football match"
    document = "The football game last night was thrilling and intense."
    prediction = predict_relevance(query, document, model, tokenizer)
    print(f"Prediction: {prediction}")
```