# IR-Electra

Fine-tunes a pre-trained ELECTRA model for document relevance classification in Information Retrieval (IR).

## Overview

This project fine-tunes ELECTRA on a synthetic dataset created for relevance classification of query-document pairs. The dataset covers three topics: sports, tech, and health, with labels indicating whether the document is relevant (1) or not (0) to the query.

## Steps

1. **Generate Synthetic Dataset**
   Create query-document pairs for three categories (sports, tech, health), labeled 1 for relevant and 0 for non-relevant.

2. **Preprocess Dataset**
   Use a custom dataset class to tokenize the query-document pairs and format them for model input.

3. **Fine-Tune ELECTRA**
   Fine-tune the ELECTRA model on the dataset to classify documents as relevant or non-relevant.
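
The custom dataset class in step 2 might look like the following minimal sketch, assuming a PyTorch `Dataset` that tokenizes each pair on access. The class name `RelevancePairDataset` and its parameters are illustrative, not taken from the project code.

```python
import torch
from torch.utils.data import Dataset

class RelevancePairDataset(Dataset):
    """Tokenizes (query, document, label) triples for sequence classification."""

    def __init__(self, pairs, tokenizer, max_length=256):
        self.pairs = pairs          # list of (query, document, label) triples
        self.tokenizer = tokenizer  # a Hugging Face tokenizer (or compatible callable)
        self.max_length = max_length

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        query, document, label = self.pairs[idx]
        # Encode query and document as a single sentence-pair input,
        # padded/truncated to a fixed length.
        encoding = self.tokenizer(
            query,
            document,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }
```

An instance of this class can be passed directly to a `DataLoader` or to the Hugging Face `Trainer` as `train_dataset`.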

## Topics

```python
topics = {
    "sports": ["football", "basketball", "tennis", "match", "score"],
    "tech": ["AI", "machine learning", "cloud", "software", "algorithm"],
    "health": ["diet", "exercise", "nutrition", "wellness", "fitness"]
}
```
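
These topic keywords seed step 1. A minimal sketch of the pair generation, assuming simple fill-in templates and one negative example (drawn from a different topic) per positive; `make_pairs` and the document templates are illustrative, not the project's actual generator:

```python
import random

topics = {
    "sports": ["football", "basketball", "tennis", "match", "score"],
    "tech": ["AI", "machine learning", "cloud", "software", "algorithm"],
    "health": ["diet", "exercise", "nutrition", "wellness", "fitness"],
}

def make_pairs(topics, n_per_topic=50, seed=0):
    """Return a shuffled list of (query, document, label) triples."""
    rng = random.Random(seed)
    names = list(topics)
    pairs = []
    for topic in names:
        for _ in range(n_per_topic):
            query = rng.choice(topics[topic])
            # Relevant pair: document built from the same topic's keywords.
            pairs.append((query, f"An article about {rng.choice(topics[topic])}.", 1))
            # Non-relevant pair: document built from a different topic.
            other = rng.choice([t for t in names if t != topic])
            pairs.append((query, f"An article about {rng.choice(topics[other])}.", 0))
    rng.shuffle(pairs)
    return pairs
```

With three topics and `n_per_topic=50`, this yields 300 balanced examples.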


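The fine-tuning in step 3 can be sketched with the Hugging Face `Trainer` API. The base checkpoint, hyperparameters, and the `fine_tune` helper below are illustrative assumptions, not the project's actual training script:

```python
from transformers import (
    ElectraForSequenceClassification,
    Trainer,
    TrainingArguments,
)

def fine_tune(train_dataset,
              model_name="google/electra-small-discriminator",
              output_dir="electra-ir"):
    """Fine-tune ELECTRA for binary relevance classification (sketch)."""
    # Two labels: 0 = non-relevant, 1 = relevant.
    model = ElectraForSequenceClassification.from_pretrained(model_name, num_labels=2)
    args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=16,
        learning_rate=2e-5,
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()
    return trainer
```

`train_dataset` is expected to yield `input_ids`, `attention_mask`, and `labels`, i.e. the output format of the tokenized dataset from step 2.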
## Usage

```python
import torch
from transformers import ElectraForSequenceClassification, ElectraTokenizer

def predict_relevance(query, document, model, tokenizer, device=None):
    """Classify a query-document pair as relevant or non-relevant."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)

    # Encode the query and document as a single sentence-pair input.
    encoding = tokenizer(
        query,
        document,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return "Relevant" if predicted_class == 1 else "Non-relevant"

if __name__ == "__main__":
    MODEL_HUB_ID = "your-hf-username/your-model-repo-name"  # <-- replace this

    print("Loading model and tokenizer from the Hugging Face Hub...")
    model = ElectraForSequenceClassification.from_pretrained(MODEL_HUB_ID)
    tokenizer = ElectraTokenizer.from_pretrained(MODEL_HUB_ID)

    query = "football match"
    document = "The football game last night was thrilling and intense."

    prediction = predict_relevance(query, document, model, tokenizer)
    print(f"Prediction: {prediction}")
```

## Model details

- Format: Safetensors
- Model size: 13.5M parameters
- Tensor type: F32