---
language:
- en
license: apache-2.0
library_name: transformers
datasets:
- ms_marco
pipeline_tag: text2text-generation
widget:
- text: how to bake perfect cookie
inference_config:
  generation_config:
    max_length: 35
    num_beams: 1
    do_sample: true
    repetition_penalty: 1.8
tags:
- code
---

## Model Summary
This is a generative model designed for search query rewriting: it uses a sequence-to-sequence architecture to generate reformulated queries. Performance is further boosted with a reinforcement learning framework built on a policy gradient algorithm, where the model is trained with reward functions that diversify the generated queries by paraphrasing keywords. It can be integrated with sparse retrieval methods, such as BM25-based retrieval, to enhance document recall in search; a sketch of that integration follows.
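
As an illustration of that integration, here is a minimal sketch. It assumes the third-party `rank_bm25` package; the toy corpus and the `generate_rewrites`/`retrieve` helpers are hypothetical, not part of this repository:

```python
import torch
from rank_bm25 import BM25Okapi
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()

def generate_rewrites(query, n=4):
    """Sample n reformulations of the query from the model."""
    input_ids = tokenizer(query, return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = model.generate(
            input_ids, max_length=35, num_beams=1, do_sample=True,
            repetition_penalty=1.8, num_return_sequences=n,
        )
    return [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

# A toy corpus; in practice this would be a real document index.
corpus = [
    "Chocolate chip cookie recipe with step-by-step baking instructions.",
    "How long to bake cookies so they stay soft and chewy.",
    "A guide to troubleshooting flat or burnt cookies.",
]
bm25 = BM25Okapi([doc.lower().split() for doc in corpus])

def retrieve(query, k=2):
    """Return the ids of the top-k documents for a single query."""
    scores = bm25.get_scores(query.lower().split())
    return sorted(range(len(corpus)), key=lambda i: scores[i], reverse=True)[:k]

query = "how to bake perfect cookie"
candidates = [query] + generate_rewrites(query)
# Union the per-query result sets: rewrites can surface documents
# the original query misses, which is the recall gain described above.
recalled = {doc_id for q in candidates for doc_id in retrieve(q)}
print([corpus[i] for i in recalled])
```

The same pattern applies to any sparse retriever; only the `retrieve` helper changes.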

### Intended use cases
- Query rewriting for search (web, e-commerce)
- Virtual assistants and chatbots
- Information retrieval

### Model Description

#### Training Procedure

1. The training process begins by initializing the sequence-to-sequence model with Google's [T5-base model](https://huggingface.co/google-t5/t5-base).
2. The model first undergoes supervised training on the [MS-MARCO query pairs dataset](https://github.com/Narabzad/msmarco-query-reformulation/tree/main/datasets/queries).
3. It is then fine-tuned with a reinforcement learning (RL) framework to enhance its ability to generate queries that are both diverse and relevant.
4. Fine-tuning uses a policy gradient approach: for a given input query, a set of trajectories (reformulated queries) is sampled from the model and a reward is computed for each. The policy gradient algorithm is then applied to update the model (see the sketch after this list).
5. Rewards are computed heuristically to enhance the model's paraphrasing capability, but they can be substituted with other domain- or goal-specific reward functions as needed.
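
As an illustration only (the repository implements its own training loop), here is a minimal REINFORCE-style sketch of steps 4 and 5, assuming a hypothetical `reward_fn(query, rewrite)` that returns a scalar paraphrasing reward:

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-base")
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

def reinforce_step(query, reward_fn, num_samples=4):
    """One policy gradient update: sample trajectories, weight their likelihoods by reward."""
    inputs = tokenizer(query, return_tensors="pt")
    # Sample a set of trajectories (reformulated queries) from the current policy.
    samples = model.generate(
        inputs.input_ids, do_sample=True, num_return_sequences=num_samples, max_length=35
    )
    loss = 0.0
    for seq in samples:
        rewrite = tokenizer.decode(seq, skip_special_tokens=True)
        # Drop the leading decoder-start token and mask padding so the sampled
        # tokens serve as labels for a teacher-forced pass.
        labels = seq[1:].unsqueeze(0).clone()
        labels[labels == tokenizer.pad_token_id] = -100
        out = model(input_ids=inputs.input_ids, labels=labels)
        # out.loss is the mean per-token NLL of this trajectory; weighting it by the
        # reward and minimizing implements REINFORCE (reward-weighted log-likelihood).
        # reward_fn is a placeholder, not the repository's exact reward implementation.
        loss = loss + reward_fn(query, rewrite) * out.loss
    (loss / num_samples).backward()
    optimizer.step()
    optimizer.zero_grad()
```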

Refer to [the project repository](https://github.com/PraveenSH/RL-Query-Reformulation) for more details.

### Model Sources

- **Repository:** https://github.com/PraveenSH/RL-Query-Reformulation

### How to use
To get the most out of this model, generate multiple candidates with sampling and a repetition penalty; this produces diverse reformulations of the same input. Sample code is provided below.
```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_ID = "prhegde/t5-query-reformulation-RL"

tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()

input_sequence = "how to bake great cookie"
input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
print(f'Input: {input_sequence}')

# Draw several independent samples; do_sample with a repetition penalty
# yields a different reformulation on each iteration.
nsent = 4
with torch.no_grad():
    for i in range(nsent):
        output = model.generate(input_ids, max_length=35, num_beams=1, do_sample=True, repetition_penalty=1.8)
        target_sequence = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Target: {target_sequence}')
```