Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from datasets import load_dataset | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import torch | |
| from transformers import AutoTokenizer, AutoModel | |
| import pickle | |
# Embeddings matrix loaded from disk; presumably one row per book, aligned
# with the rows of `books_data` (TODO confirm against the script that
# produced the pickle).
# NOTE(review): unpickling can execute arbitrary code, so only ship a
# 'book_embeddings.pkl' that this project generated itself.
with open('book_embeddings.pkl', 'rb') as embeddings_file:
    book_embeddings = pickle.load(embeddings_file)

# The tokenizer and the encoder are both loaded from the same checkpoint.
model_checkpoint = 'intfloat/multilingual-e5-large'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)

# Book dataset with the train and test splits combined; the pandas format
# makes indexing return pandas objects instead of plain dicts.
books_data = load_dataset('vojtam/czech_books_descriptions', split="train+test")
books_data.set_format('pandas')
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool hidden states along dimension 1, ignoring masked positions.

    Positions where ``attention_mask`` is zero are zeroed out before summing,
    and the sum is divided by the per-row count of unmasked positions.

    Args:
        last_hidden_states: Tensor of hidden states; presumably shaped
            (batch, sequence, hidden) -- TODO confirm with callers.
        attention_mask: Mask tensor with one fewer dimension than
            ``last_hidden_states``; non-zero marks positions to keep.

    Returns:
        Tensor of pooled vectors with dimension 1 reduced away.
    """
    # Add a trailing axis so the mask broadcasts across the hidden dimension.
    keep = attention_mask.unsqueeze(-1).bool()
    masked_states = last_hidden_states.masked_fill(~keep, 0.0)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return masked_states.sum(dim=1) / token_counts
def create_embeddings(tokenizer, model, input_texts, batch_size=32):
    """Embed texts in batches and return L2-normalised embeddings.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, max_length=512,
            padding=True, truncation=True, return_tensors='pt')``; its result
            must be a mapping that contains ``'attention_mask'`` and can be
            unpacked into ``model``.
        model: Callable whose output exposes ``last_hidden_state``.
        input_texts: List of strings to embed. A single string is also
            accepted and is treated as a one-element list.
        batch_size: Number of texts processed per forward pass.

    Returns:
        A tensor with one normalised embedding row per input text, formed by
        concatenating the per-batch results along dimension 0.
    """
    if isinstance(input_texts, str):
        # Slicing a bare string below would cut it into `batch_size`-character
        # chunks and embed each chunk as if it were a separate text.
        input_texts = [input_texts]

    total = len(input_texts)
    embeddings_list = []
    for start in range(0, total, batch_size):
        batch_texts = input_texts[start:start + batch_size]
        batch_dict = tokenizer(batch_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**batch_dict)

        batch_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
        embeddings_list.append(batch_embeddings)

        # Report progress every ten batches; cap the count so the final,
        # possibly shorter, batch never reports more texts than exist.
        if (start + batch_size) % (batch_size * 10) == 0:
            print(f"Processed {min(start + batch_size, total)}/{total} texts")

    return torch.cat(embeddings_list, dim=0)
def find_similar_books(query: str, n = 5):
    """Return the ``n`` books whose embeddings score highest against ``query``.

    Args:
        query: Free-text description of the desired book. ``None`` is
            treated as an empty string.
        n: Number of books to return. Coerced to ``int`` and clamped to at
            least 1, because UI number widgets may deliver floats and a value
            of 0 would otherwise select every book (``scores[-0:]``).

    Returns:
        The rows of ``books_data`` at the top-scoring indices, ordered from
        highest to lowest score.
    """
    n = max(int(n), 1)

    # The text box can hand over None after it has been cleared.
    input_query = "query: " + (query or "")

    # Pass a one-element list: a bare string would be sliced into
    # character chunks and only the first chunk would be scored.
    query_embedding = create_embeddings(tokenizer, model, [input_query])

    scores = ((query_embedding @ book_embeddings.T) * 100).detach().numpy()[0]

    # argsort is ascending, so take the last n entries and reverse them.
    top_indices = np.argsort(scores)[-n:][::-1]
    return books_data[top_indices]
# Custom styles: a scrollable results area and an orange submit button.
css = """
.full-height-gallery {
height: calc(100vh - 250px);
overflow-y: auto;
}
#submit-btn {
background-color: #ff5b00;
color: #ffffff;
}
"""

# NOTE(review): the nesting below is reconstructed from a flattened source;
# confirm the row layout against the running app.
with gr.Blocks(css=css) as intf:
    # Inputs: the free-text description and the number of results.
    with gr.Row():
        text_input = gr.Textbox(
            label="Popis knihy",
            info="Zadejte popis knihy, kterou byste si chtěli přečíst a aplikace najde nejpodobnější knihy dle vašeho popisu",
            placeholder='Zadejte popis, například "drama z prostředí nemocnice"',
        )
        # precision=0 makes the component hand an int to the callback, which
        # uses the value as a slice bound.
        n_books = gr.Number(
            value=5,
            label="Počet knih",
            info="Počet nejpodobnějších knih, které si přejete zobrazit",
            minimum=1,
            step=1,
            precision=0,
        )

    with gr.Row():
        submit_btn = gr.Button("Vyhledat knihy", elem_id="submit-btn")
        clear_btn = gr.Button("Smazat")

    with gr.Row():
        dataframe = gr.Dataframe(label="Podobné knihy", show_label=False, elem_classes=["full-height-gallery"])

    # Event listeners are registered while the Blocks context is open.
    submit_btn.click(fn=find_similar_books, inputs=[text_input, n_books], outputs=dataframe)
    # Reset both the query box and the results table.
    clear_btn.click(fn=lambda: [None, []], inputs=None, outputs=[text_input, dataframe])

intf.launch(share=True)