File size: 1,804 Bytes
e0b6877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

class QueryDecomposer:
    """Turn one query into at most three shorter sub-queries.

    The primary path prompts a seq2seq model with three fixed question
    templates (definition/architecture, advantages, experiments) and uses the
    decoded generations as sub-queries. If none of the generations is usable,
    a keyword-based split of the query is returned instead.
    """

    # Upper bound on the number of sub-queries returned by either path.
    MAX_SUB_QUERIES = 3

    # One prompt per aspect of the query; ``{query}`` is filled in verbatim.
    _PROMPT_TEMPLATES = (
        "What is the definition and architecture of the main concept in: {query}",
        "What are the key advantages and improvements of: {query}",
        "What experiments or benchmarks compare the methods in: {query}",
    )

    # Words dropped by the rule-based fallback before splitting the query.
    _STOP_WORDS = frozenset({
        "is", "are", "was", "were", "the", "a", "an", "why",
        "how", "what", "which", "does", "do", "than",
    })

    def __init__(self, model_name: str = "google/flan-t5-large") -> None:
        """Load the tokenizer and seq2seq model and switch the model to eval mode.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.
                Defaults to the checkpoint this class originally hard-coded.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.eval()

    def decompose(self, query: str) -> list:
        """Return up to ``MAX_SUB_QUERIES`` distinct sub-queries for *query*.

        Args:
            query: The user query to decompose.

        Returns:
            A list of strings. Empty when *query* is ``None``/blank, or when
            neither the model nor the fallback yields a candidate longer than
            five characters.
        """
        # Nothing meaningful can be generated from an empty prompt, and the
        # fallback would return [] anyway, so skip the model entirely.
        if not query or not query.strip():
            return []

        normalized_query = query.strip().lower()
        generated = []

        for template in self._PROMPT_TEMPLATES:
            prompt = template.format(query=query)
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True, max_length=128
            )
            # Keep the inputs on the same device as the model so generation
            # still works if the caller moved the model (e.g. to a GPU).
            inputs = inputs.to(self.model.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=64,
                    num_beams=2,
                    early_stopping=True,
                )
            text = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
            # Discard trivially short generations and plain echoes of the query.
            if len(text) > 5 and text.lower() != normalized_query:
                generated.append(text)

        sub_queries = self._unique_sub_queries(generated)
        if not sub_queries:
            return self._rule_based_decompose(query)

        return sub_queries

    def _rule_based_decompose(self, query: str) -> list:
        """Build sub-queries from the query's keywords without using the model.

        Candidates are, in order: the query itself, the first half of its
        non-stop-word keywords, and the second half (the halves share the
        middle keyword). Candidates of five characters or fewer and
        case-insensitive duplicates are dropped.
        """
        keywords = [w for w in query.lower().split() if w not in self._STOP_WORDS]
        middle = len(keywords) // 2

        candidates = [
            query,
            " ".join(keywords[:middle + 1]),
            " ".join(keywords[middle:]),
        ]

        return self._unique_sub_queries([c for c in candidates if len(c) > 5])

    @classmethod
    def _unique_sub_queries(cls, candidates) -> list:
        """Drop case-insensitive duplicates, keep order, cap at MAX_SUB_QUERIES."""
        seen = set()
        unique = []
        for candidate in candidates:
            key = candidate.strip().casefold()
            if key in seen:
                continue
            seen.add(key)
            unique.append(candidate)
            if len(unique) == cls.MAX_SUB_QUERIES:
                break
        return unique