YAML Metadata Warning: empty or missing yaml metadata in repo card

Check out the documentation for more information.

Uses

You can run the example code below on Google Colab.

  1. First, define the `ModernBertForQueryComparison` class:
import torch
from transformers import AutoTokenizer, ModernBertPreTrainedModel, ModernBertModel
from torch import nn

# Define the model class (same as the one used during training)
class ModernBertForQueryComparison(ModernBertPreTrainedModel):
    """Pairwise query scorer built on a single shared ModernBERT encoder.

    Subclasses ``ModernBertPreTrainedModel`` so the weights can be loaded with
    ``from_pretrained``. The "article" input and the "sentence" input are each
    encoded by the same encoder and mapped to a scalar by the same linear
    head; the model's prediction is the difference of the two scalars.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE: submodule attribute names determine the parameter names in the
        # state dict, so they must stay as-is for saved checkpoints to load.
        self.bert = ModernBertModel(config)
        # Fall back to 0.1 when the config does not define `mlp_dropout`.
        self.dropout = nn.Dropout(getattr(config, 'mlp_dropout', 0.1))
        self.score_predictor = nn.Linear(config.hidden_size, 1)

        # Initialize the newly added layers.
        self.post_init()

    def _score(self, input_ids, attention_mask):
        """Encode one input and return its scalar score, shape [batch, 1]."""
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # last_hidden_state is [batch, seq_len, hidden_dim]; take the vector at
        # position 0 (the [CLS] position) as the sequence representation.
        cls_vector = encoded.last_hidden_state[:, 0, :]
        return self.score_predictor(self.dropout(cls_vector))

    def forward(self,
                article_input_ids=None,
                article_attention_mask=None,
                sentence_input_ids=None,
                sentence_attention_mask=None,
                labels=None):
        """Score both inputs and, if labels are given, compute the pairwise loss.

        Args:
            article_input_ids: token ids of the article-side input.
            article_attention_mask: attention mask for ``article_input_ids``.
            sentence_input_ids: token ids of the sentence-side input.
            sentence_attention_mask: attention mask for ``sentence_input_ids``.
            labels: optional targets; 1 means the sentence side is better,
                0 means the article side is better. Integer or boolean labels
                are accepted and converted to the logits' dtype.

        Returns:
            dict with keys ``'loss'`` (``None`` when ``labels`` is ``None``),
            ``'relative_score'`` (sentence score minus article score),
            ``'article_score'`` and ``'sentence_score'``; the three score
            tensors have shape [batch, 1].
        """
        # Article side first, then sentence side (same encoder, same head).
        article_score = self._score(article_input_ids, article_attention_mask)
        sentence_score = self._score(sentence_input_ids, sentence_attention_mask)

        # The raw difference is used as a logit (no sigmoid here):
        # > 0 means the sentence side is better, < 0 the article side.
        relative_score = sentence_score - article_score  # [batch, 1]

        loss = None
        if labels is not None:
            targets = labels.view(-1)
            # BCEWithLogitsLoss needs floating-point targets; 0/1 labels often
            # arrive as integer tensors, which would otherwise raise an error.
            if not targets.is_floating_point():
                targets = targets.to(relative_score.dtype)
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(relative_score.view(-1), targets)

        return {
            'loss': loss,
            'relative_score': relative_score,
            'article_score': article_score,
            'sentence_score': sentence_score,
        }
  1. Then define the `predict_better_query` function:
# Prediction function
def predict_better_query(model, tokenizer, query, article_query, sentence_query, device,
                         max_length=8192, padding='longest'):
    """Decide which of two candidate queries the model scores higher.

    Each candidate is tokenized as a pair with ``query`` and both encodings
    are passed to ``model`` in a single call with gradients disabled.

    Args:
        model: callable taking ``article_input_ids``, ``article_attention_mask``,
            ``sentence_input_ids``, ``sentence_attention_mask`` and ``labels``
            keyword arguments and returning a mapping with ``'relative_score'``,
            ``'article_score'`` and ``'sentence_score'`` single-element tensors.
            It is switched to eval mode as a side effect.
        tokenizer: Hugging Face style tokenizer, called on a text pair.
        query: the original query text.
        article_query: the article-side candidate query.
        sentence_query: the sentence-side candidate query.
        device: device the encoded tensors are moved to.
        max_length: maximum token length; longer inputs are truncated.
        padding: padding strategy forwarded to the tokenizer. The default
            ``'longest'`` adds no padding for a single pair, so a short input
            is not inflated to ``max_length`` tokens. Pass ``'max_length'``
            to pad every input up to ``max_length``.

    Returns:
        dict with ``'is_sentence_better'`` (bool, True when the relative score
        is > 0) and the float values ``'relative_score'``, ``'article_score'``
        and ``'sentence_score'``.
    """
    model.eval()

    def encode(candidate):
        # Tokenize (query, candidate) as a pair and move the tensors to `device`.
        encoding = tokenizer(
            query,
            candidate,
            add_special_tokens=True,
            max_length=max_length,
            padding=padding,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return encoding['input_ids'].to(device), encoding['attention_mask'].to(device)

    article_input_ids, article_attention_mask = encode(article_query)
    sentence_input_ids, sentence_attention_mask = encode(sentence_query)

    with torch.no_grad():
        outputs = model(
            article_input_ids=article_input_ids,
            article_attention_mask=article_attention_mask,
            sentence_input_ids=sentence_input_ids,
            sentence_attention_mask=sentence_attention_mask,
            labels=None
        )

    # A single pair was encoded, so each score tensor holds exactly one value.
    relative_score = outputs['relative_score'].item()

    return {
        # relative_score > 0 means sentence_query is the better one.
        'is_sentence_better': relative_score > 0,
        'relative_score': relative_score,
        'article_score': outputs['article_score'].item(),
        'sentence_score': outputs['sentence_score'].item()
    }
  1. Run inference:
def main():
    """Load the model and tokenizer, then print predictions for sample queries.

    For every sample, the original query and its two candidate rewrites are
    printed, followed by the scores returned by ``predict_better_query`` and
    a one-line conclusion naming the better candidate.
    """
    # Model location handed directly to `from_pretrained`.
    model_dir = "CheWei/ModernBERT_16x2_1e-5_8192"

    # Pick the device: GPU when available, otherwise CPU.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    print(f"Using device: {device}")

    # Load the pretrained tokenizer and model.
    print(f"Loading model from {model_dir}...")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = ModernBertForQueryComparison.from_pretrained(model_dir)
    model.to(device)
    print("Model loaded successfully!")

    # Sample inputs as (query, article_query, sentence_query) triples.
    samples = [
        (
            "Python programming tutorials",
            "Python programming guides and examples",
            "How to learn Python programming",
        ),
        (
            "Healthy breakfast ideas",
            "Nutritious breakfast recipes for busy mornings",
            "Quick and healthy breakfast options",
        ),
        (
            "史丹佛大學課程",
            "史丹佛大學開設的各類專業課程介紹與選擇指南",
            "史丹佛大學有哪些熱門課程可以選擇",
        ),
    ]

    print("\n=== Running Test Examples ===")

    separator = "-" * 50
    for number, (query, article_query, sentence_query) in enumerate(samples, start=1):
        print(f"\nExample {number}:")
        print(f"Query: {query}")
        print(f"Article query: {article_query}")
        print(f"Sentence query: {sentence_query}")

        # Run the comparison for this sample.
        result = predict_better_query(
            model, tokenizer, query, article_query, sentence_query, device
        )

        # Report the individual scores and their difference.
        print("--- Results ---")
        print(f"Article query score: {result['article_score']:.4f}")
        print(f"Sentence query score: {result['sentence_score']:.4f}")
        print(f"Relative score: {result['relative_score']:.4f}")

        conclusion = (
            "Conclusion: ✓ Sentence query is better"
            if result['is_sentence_better']
            else "Conclusion: ✓ Article query is better"
        )
        print(conclusion)
        print(separator)

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Downloads last month
1
Safetensors
Model size
0.4B params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support