Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
-
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torch import Tensor
|
| 5 |
from transformers import AutoTokenizer, AutoModel
|
|
|
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
import os
|
| 8 |
|
|
@@ -33,6 +35,12 @@ tasks = {
|
|
| 33 |
'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query',
|
| 34 |
}
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-mistral-7b-instruct')
|
| 37 |
model = AutoModel.from_pretrained('intfloat/e5-mistral-7b-instruct', torch_dtype=torch.float16, device_map=device)
|
| 38 |
|
|
@@ -56,15 +64,33 @@ def load_corpus_from_json(file_path):
|
|
| 56 |
with open(file_path, 'r') as file:
|
| 57 |
data = json.load(file)
|
| 58 |
return data
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def compute_embeddings(selected_task, input_text):
|
| 62 |
try:
|
| 63 |
task_description = tasks[selected_task]
|
| 64 |
except KeyError:
|
| 65 |
print(f"Selected task not found: {selected_task}")
|
| 66 |
return f"Error: Task '{selected_task}' not found. Please select a valid task."
|
| 67 |
-
max_length =
|
| 68 |
processed_texts = [f'Instruct: {task_description}\nQuery: {input_text}']
|
| 69 |
|
| 70 |
batch_dict = tokenizer(processed_texts, max_length=max_length - 1, return_attention_mask=False, padding=False, truncation=True)
|
|
@@ -75,9 +101,20 @@ def compute_embeddings(selected_task, input_text):
|
|
| 75 |
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
| 76 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
| 77 |
embeddings_list = embeddings.detach().cpu().numpy().tolist()
|
|
|
|
| 78 |
return embeddings_list
|
| 79 |
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def compute_similarity(selected_task, sentence1, sentence2, extra_sentence1, extra_sentence2):
|
| 82 |
try:
|
| 83 |
task_description = tasks[selected_task]
|
|
@@ -105,17 +142,20 @@ def compute_similarity(selected_task, sentence1, sentence2, extra_sentence1, ext
|
|
| 105 |
free_memory(embeddings1, embeddings2, embeddings3, embeddings4)
|
| 106 |
|
| 107 |
similarity_scores = {"Similarity 1-2": similarity1, "Similarity 1-3": similarity2, "Similarity 1-4": similarity3}
|
|
|
|
| 108 |
return similarity_scores
|
| 109 |
|
| 110 |
-
|
| 111 |
def compute_cosine_similarity(emb1, emb2):
    """Return the cosine similarity of two embeddings as a Python number.

    Both inputs are converted to half-precision tensors on ``device`` and
    compared with ``F.cosine_similarity`` using its default dimension.
    ``.item()`` only works on a one-element result, so the inputs are
    presumably single embeddings (e.g. ``[[...]]`` lists) -- TODO confirm.
    """
    left, right = (torch.tensor(vec).to(device).half() for vec in (emb1, emb2))
    score = F.cosine_similarity(left, right).item()
    # Hand the temporary tensors to the file's cleanup helper before returning.
    free_memory(left, right)
    return score
|
| 117 |
|
| 118 |
|
|
|
|
| 119 |
def compute_embeddings_batch(input_texts):
|
| 120 |
max_length = 2042
|
| 121 |
processed_texts = [f'Instruct: {task_description}\nQuery: {text}' for text in input_texts]
|
|
@@ -127,6 +167,7 @@ def compute_embeddings_batch(input_texts):
|
|
| 127 |
outputs = model(**batch_dict)
|
| 128 |
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
| 129 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
|
|
|
| 130 |
return embeddings.detach().cpu().numpy()
|
| 131 |
|
| 132 |
def semantic_search(query_embedding, corpus_embeddings, top_k=5):
|
|
@@ -140,6 +181,31 @@ def search_similar_sentences(input_question, corpus_sentences, corpus_embeddings
|
|
| 140 |
results = [(corpus_sentences[i], top_k_scores[i]) for i in top_k_indices]
|
| 141 |
return results
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
def app_interface():
|
| 145 |
corpus_sentences = []
|
|
@@ -210,6 +276,31 @@ def app_interface():
|
|
| 210 |
outputs=search_results_output
|
| 211 |
)
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
with gr.Row():
|
| 214 |
with gr.Column():
|
| 215 |
input_text_box
|
|
@@ -219,5 +310,5 @@ def app_interface():
|
|
| 219 |
|
| 220 |
return demo
|
| 221 |
|
| 222 |
-
|
| 223 |
app_interface().launch()
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torch import Tensor
|
| 5 |
from transformers import AutoTokenizer, AutoModel
|
| 6 |
+
import threading
|
| 7 |
+
import queue
|
| 8 |
import gradio as gr
|
| 9 |
import os
|
| 10 |
|
|
|
|
| 35 |
'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query',
|
| 36 |
}
|
| 37 |
|
| 38 |
+
|
| 39 |
+
# Global queues: requests sent to the embedding worker, and the responses it produces
|
| 40 |
+
embedding_request_queue = queue.Queue()
|
| 41 |
+
embedding_response_queue = queue.Queue()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-mistral-7b-instruct')
|
| 45 |
model = AutoModel.from_pretrained('intfloat/e5-mistral-7b-instruct', torch_dtype=torch.float16, device_map=device)
|
| 46 |
|
|
|
|
| 64 |
with open(file_path, 'r') as file:
|
| 65 |
data = json.load(file)
|
| 66 |
return data
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def embedding_worker():
    """Serve embedding requests from ``embedding_request_queue`` until shut down.

    Each request is a ``(selected_task, input_text)`` tuple. For every request
    exactly one response is placed on ``embedding_response_queue``: either the
    ``format_response`` payload or an ``{"error": ...}`` dict. A ``None`` item
    is the shutdown sentinel and ends the loop.
    """
    while True:
        # Block until a request (or the shutdown sentinel) is available.
        item = embedding_request_queue.get()
        if item is None:
            # Balance the get() so embedding_request_queue.join() can return.
            embedding_request_queue.task_done()
            break

        try:
            selected_task, input_text = item
            embeddings = compute_embeddings(selected_task, input_text)
            if isinstance(embeddings, str):
                # compute_embeddings reports an unknown task as a plain
                # message; surface it as an error instead of an embedding.
                response = {"error": embeddings}
            else:
                response = format_response(embeddings)
        except Exception as exc:
            # Thread boundary: the caller is blocked on the response queue, so
            # a failure must still produce a response and must not kill the
            # only worker thread.
            print(f"Embedding request failed: {exc}")
            response = {"error": f"Error: {exc}"}

        embedding_response_queue.put(response)
        embedding_request_queue.task_done()
        clear_cuda_cache()


threading.Thread(target=embedding_worker, daemon=True).start()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@spaces.GPU
|
| 87 |
def compute_embeddings(selected_task, input_text):
|
| 88 |
try:
|
| 89 |
task_description = tasks[selected_task]
|
| 90 |
except KeyError:
|
| 91 |
print(f"Selected task not found: {selected_task}")
|
| 92 |
return f"Error: Task '{selected_task}' not found. Please select a valid task."
|
| 93 |
+
max_length = 2048
|
| 94 |
processed_texts = [f'Instruct: {task_description}\nQuery: {input_text}']
|
| 95 |
|
| 96 |
batch_dict = tokenizer(processed_texts, max_length=max_length - 1, return_attention_mask=False, padding=False, truncation=True)
|
|
|
|
| 101 |
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
| 102 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
| 103 |
embeddings_list = embeddings.detach().cpu().numpy().tolist()
|
| 104 |
+
clear_cuda_cache()
|
| 105 |
return embeddings_list
|
| 106 |
|
| 107 |
+
@spaces.GPU
def decode_embedding(embedding_str):
    """Attempt to decode a comma-separated string of floats with the tokenizer.

    The floats are parsed, placed in a float16 tensor on ``device``, and the
    first element is passed to ``tokenizer.decode``. Any exception is caught
    and returned as an ``"Error in decoding: ..."`` string rather than raised.

    NOTE(review): ``tokenizer.decode`` is normally given token ids and returns
    text, so the ``.cpu()`` call on its result looks unable to succeed --
    confirm whether this function is meant to work; the UI tab that would call
    it is commented out in this file.
    """
    try:
        values = list(map(float, embedding_str.split(',')))
        as_tensor = torch.tensor(values, dtype=torch.float16, device=device)
        decoded = tokenizer.decode(as_tensor[0], skip_special_tokens=True)
        result = decoded.cpu().numpy().tolist()
    except Exception as e:
        return f"Error in decoding: {str(e)}"
    return result
|
| 116 |
+
|
| 117 |
+
@spaces.GPU
|
| 118 |
def compute_similarity(selected_task, sentence1, sentence2, extra_sentence1, extra_sentence2):
|
| 119 |
try:
|
| 120 |
task_description = tasks[selected_task]
|
|
|
|
| 142 |
free_memory(embeddings1, embeddings2, embeddings3, embeddings4)
|
| 143 |
|
| 144 |
similarity_scores = {"Similarity 1-2": similarity1, "Similarity 1-3": similarity2, "Similarity 1-4": similarity3}
|
| 145 |
+
clear_cuda_cache()
|
| 146 |
return similarity_scores
|
| 147 |
|
| 148 |
+
@spaces.GPU
def compute_cosine_similarity(emb1, emb2):
    """Return the cosine similarity of two embeddings as a Python number.

    Each input is turned into a half-precision tensor on ``device``; the pair
    is compared with ``F.cosine_similarity`` (default dimension) and the
    one-element result is unwrapped with ``.item()``. The temporaries are then
    passed to ``free_memory`` and ``clear_cuda_cache`` is called.
    """
    vec_a = torch.tensor(emb1).to(device).half()
    vec_b = torch.tensor(emb2).to(device).half()

    score = F.cosine_similarity(vec_a, vec_b).item()

    # Clean up before returning so only the plain number leaves the function.
    free_memory(vec_a, vec_b)
    clear_cuda_cache()
    return score
|
| 156 |
|
| 157 |
|
| 158 |
+
@spaces.GPU
|
| 159 |
def compute_embeddings_batch(input_texts):
|
| 160 |
max_length = 2042
|
| 161 |
processed_texts = [f'Instruct: {task_description}\nQuery: {text}' for text in input_texts]
|
|
|
|
| 167 |
outputs = model(**batch_dict)
|
| 168 |
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
| 169 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
| 170 |
+
clear_cuda_cache()
|
| 171 |
return embeddings.detach().cpu().numpy()
|
| 172 |
|
| 173 |
def semantic_search(query_embedding, corpus_embeddings, top_k=5):
|
|
|
|
| 181 |
results = [(corpus_sentences[i], top_k_scores[i]) for i in top_k_indices]
|
| 182 |
return results
|
| 183 |
|
| 184 |
+
# openai response object formatting
|
| 185 |
+
def format_response(embeddings, *, model="e5-mistral", prompt_tokens=17):
    """Wrap embeddings in an OpenAI-style embeddings response dict.

    Args:
        embeddings: Value stored under ``data[0]["embedding"]`` unchanged.
        model: Model name reported in the response.
        prompt_tokens: Token count reported under ``usage`` for both
            ``prompt_tokens`` and ``total_tokens``. The default of 17 is the
            previously hard-coded placeholder; it is not derived from the
            input, so pass the real count when it is known.

    Returns:
        A dict with ``data``, ``model``, ``object`` and ``usage`` keys.
    """
    return {
        "data": [
            {
                "embedding": embeddings,
                "index": 0,
                "object": "embedding",
            },
        ],
        "model": model,
        "object": "list",
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens,
        },
    }
|
| 201 |
+
|
| 202 |
+
# Serializes submit-and-wait so each caller receives the response to its own
# request: the response queue is shared and carries no request identifier.
_embedding_call_lock = threading.Lock()


def generate_and_format_embeddings(selected_task, input_text):
    """Submit one embedding request to the worker and return its response.

    Args:
        selected_task: Task name forwarded to the worker with the request.
        input_text: Text to embed.

    Returns:
        Whatever the worker placed on ``embedding_response_queue`` for this
        request.
    """
    with _embedding_call_lock:
        embedding_request_queue.put((selected_task, input_text))
        # Blocks until the worker has answered the request just submitted.
        response = embedding_response_queue.get()
        embedding_response_queue.task_done()
    clear_cuda_cache()
    return response
|
| 208 |
+
|
| 209 |
|
| 210 |
def app_interface():
|
| 211 |
corpus_sentences = []
|
|
|
|
| 276 |
outputs=search_results_output
|
| 277 |
)
|
| 278 |
|
| 279 |
+
with gr.Tab("Connector-like Embeddings"):
|
| 280 |
+
with gr.Row():
|
| 281 |
+
input_text_box_connector = gr.Textbox(label="Input Text", placeholder="Enter text or array of texts")
|
| 282 |
+
# The selected value is passed on as the task name and looked up in `tasks`,
# so the default must be one of the listed choices.
model_dropdown_connector = gr.Dropdown(label="Model", choices=["ArguAna", "ClimateFEVER", "DBPedia", "FEVER", "FiQA2018", "HotpotQA", "MSMARCO", "NFCorpus", "NQ", "QuoraRetrieval", "SCIDOCS", "SciFact", "Touche2020", "TRECCOVID"], value="TRECCOVID")
|
| 283 |
+
encoding_format_connector = gr.Radio(label="Encoding Format", choices=["float", "base64"], value="float")
|
| 284 |
+
user_connector = gr.Textbox(label="User", placeholder="Enter user identifier (optional)")
|
| 285 |
+
submit_button_connector = gr.Button("Generate Embeddings")
|
| 286 |
+
output_display_connector = gr.JSON(label="Embeddings Output")
|
| 287 |
+
submit_button_connector.click(
|
| 288 |
+
fn=generate_and_format_embeddings,
|
| 289 |
+
inputs=[model_dropdown_connector, input_text_box_connector],
|
| 290 |
+
outputs=output_display_connector
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# with gr.Tab("Decode Embedding"):
|
| 294 |
+
# embedding_input = gr.Textbox(label="Enter Embedding (comma-separated floats)")
|
| 295 |
+
# decode_button = gr.Button("Decode")
|
| 296 |
+
# decoded_output = gr.Textbox(label="Decoded Embedding")
|
| 297 |
+
#
|
| 298 |
+
# decode_button.click(
|
| 299 |
+
# fn=decode_embedding,
|
| 300 |
+
# inputs=embedding_input,
|
| 301 |
+
# outputs=decoded_output
|
| 302 |
+
# )
|
| 303 |
+
|
| 304 |
with gr.Row():
|
| 305 |
with gr.Column():
|
| 306 |
input_text_box
|
|
|
|
| 310 |
|
| 311 |
return demo
|
| 312 |
|
| 313 |
+
# Build the UI once and enable queuing on the same instance that is launched;
# calling app_interface() twice would queue one instance and serve another.
demo = app_interface()
demo.queue()
demo.launch()
|