File size: 2,684 Bytes
f7c34c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
from transformers import ErnieModel, ErnieTokenizer

class ErnieBotDeepSearch(nn.Module):
    """Query-document deep-search ranker built on an ERNIE encoder.

    Pipeline: encode the query and each document with ERNIE, refine each
    document through a stack of Transformer encoder layers, cross-attend
    the query over the document in a shared 1024-d "knowledge" space, and
    emit one relevance score per document.
    """

    def __init__(self):
        super().__init__()
        self.name = "ErnieBot Deep Search"
        self.version = "Original 1.0"

        # Core components: pretrained ERNIE encoder + matching tokenizer.
        self.ernie = ErnieModel.from_pretrained("ernie-3.0-base-zh")
        self.tokenizer = ErnieTokenizer.from_pretrained("ernie-3.0-base-zh")

        # Deep-search refinement stack. batch_first=True so these layers
        # accept the (batch, seq, hidden) tensors HF models produce; the
        # default (seq, batch, hidden) layout silently transposes the batch
        # and sequence axes and breaks cross-attention for inputs of
        # differing token lengths.
        self.search_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=768, nhead=12,
                                       batch_first=True)
            for _ in range(6)
        ])

        # Projection from ERNIE's 768-d space into the 1024-d space shared
        # by cross-attention and the output heads.
        self.knowledge_encoder = nn.Linear(768, 1024)
        self.cross_attention = nn.MultiheadAttention(1024, 16,
                                                     batch_first=True)

        # Output heads. `classifier` is currently unused by any method here
        # but is kept for interface compatibility.
        self.classifier = nn.Linear(1024, 2)
        self.ranking_head = nn.Linear(1024, 1)

    def deep_search(self, query, documents):
        """Score *documents* (iterable of str) against *query* (str).

        Returns a 1-D tensor of shape (num_documents,) with one relevance
        score per document.
        """
        # NOTE(review): no truncation/max_length is set, so over-long text
        # can exceed ERNIE's position limit — confirm inputs are
        # pre-truncated or add truncation=True.
        query_tokens = self.tokenizer(query, return_tensors="pt")
        query_embed = self.ernie(**query_tokens)[0]  # (1, Lq, 768)

        # Encode every document independently; lengths Ld may differ.
        doc_embeddings = []
        for doc in documents:
            doc_tokens = self.tokenizer(doc, return_tensors="pt")
            doc_embeddings.append(self.ernie(**doc_tokens)[0])  # (1, Ld, 768)

        search_results = self._process_deep_search(query_embed, doc_embeddings)
        return self._rank_results(search_results)

    def _process_deep_search(self, query, documents):
        """Refine each document and cross-attend the query over it.

        query: (1, Lq, 768); documents: list of (1, Ld, 768) tensors with
        possibly different lengths Ld. Returns a list of (1, Lq, 1024)
        attention outputs, one per document.
        """
        query_enhanced = self.knowledge_encoder(query)  # (1, Lq, 1024)

        results = []
        for doc in documents:
            # Refine the document representation through the search stack.
            for layer in self.search_layers:
                doc = layer(doc)

            # Query attends over the document in the shared 1024-d space;
            # the output length follows the query, so every result has the
            # same shape regardless of document length.
            doc_enhanced = self.knowledge_encoder(doc)
            attention_output, _ = self.cross_attention(
                query_enhanced, doc_enhanced, doc_enhanced
            )
            results.append(attention_output)
        return results

    def _rank_results(self, search_results):
        """Collapse each (1, Lq, 1024) attention output to a scalar score.

        Mean-pools over the query positions before the ranking head so each
        document yields exactly one score (the previous per-token scoring
        never produced a per-document scalar, and its bare squeeze()
        collapsed to 0-dim for a single document). Returns a 1-D tensor of
        shape (num_documents,).
        """
        rankings = []
        for result in search_results:
            pooled = result.mean(dim=1)                 # (1, 1024)
            rankings.append(self.ranking_head(pooled))  # (1, 1)
        return torch.stack(rankings).view(-1)           # (num_docs,)

    def train_step(self, batch):
        """Margin-ranking loss pushing every positive above every negative.

        batch: (query, positive_docs, negative_docs). All positive/negative
        score pairs are compared via broadcasting, which also works when the
        two document lists have different lengths (the original element-wise
        pairing required equal counts and crashed otherwise).
        """
        query, positive_docs, negative_docs = batch
        pos_scores = self.deep_search(query, positive_docs)  # (P,)
        neg_scores = self.deep_search(query, negative_docs)  # (N,)

        pos = pos_scores.unsqueeze(1)  # (P, 1)
        neg = neg_scores.unsqueeze(0)  # (1, N)
        # target=1 means "first input should rank higher than the second".
        target = torch.ones(pos_scores.numel(), neg_scores.numel())
        return nn.MarginRankingLoss(margin=1.0)(pos, neg, target)