# NOTE(review): removed a non-code scrape artifact here ("Spaces:" /
# "Runtime error" hosting-page status banner) — it was not part of the module.
| """ Semantically picks most similar tags in training space of | |
| generative model """ | |
| # IMPORTS | |
| import json | |
| import torch | |
| from typing import List | |
| from transformers import AutoTokenizer, AutoModel | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import streamlit as st | |
| # FUNCTIONS | |
| # create embeddings | |
def get_embeddings(text: str, token_length: int, tokenizer, model):
    """Embed *text* as the mean of the model's last hidden state.

    Args:
        text: Input string to embed.
        token_length: Fixed sequence length; shorter inputs are padded
            to it, longer ones truncated.
        tokenizer: HuggingFace-style tokenizer — a callable returning an
            object with ``input_ids`` and ``attention_mask`` lists.
        model: HuggingFace-style model loaded with
            ``output_hidden_states=True``.

    Returns:
        numpy array of shape (1, hidden_size): the sequence-mean of the
        final hidden-state layer.
    """
    tokens = tokenizer(text, max_length=token_length,
                       padding='max_length', truncation=True)
    input_ids = torch.tensor(tokens.input_ids).unsqueeze(0)
    attention_mask = torch.tensor(tokens.attention_mask).unsqueeze(0)
    # Inference only: no_grad avoids building an autograd graph
    # (the original relied on .detach() after the fact).
    with torch.no_grad():
        output = model(input_ids,
                       attention_mask=attention_mask).hidden_states[-1]
    # Mean-pool over the sequence dimension (torch keyword is `dim`,
    # not numpy's `axis`).
    return torch.mean(output, dim=1).numpy()
| # get doc with highest similarity to query | |
def nearest_doc(doc_list: List[str],
                query: str,
                tokenizer,
                model,
                token_length: int = 10):
    """Return the document in *doc_list* most similar to *query*.

    Similarity is cosine similarity between mean-pooled embeddings
    produced by ``get_embeddings``.

    Args:
        doc_list: Candidate documents; must be non-empty.
        query: Query string to match against the candidates.
        tokenizer: Tokenizer passed through to ``get_embeddings``.
        model: Model passed through to ``get_embeddings``.
        token_length: Sequence length used for every embedding.

    Returns:
        The most similar document — the query itself if it already
        appears in ``doc_list``. Ties resolve to the earliest candidate.

    Raises:
        ValueError: If ``doc_list`` is empty (and the query is not in it).
    """
    # Exact match short-circuits the expensive embedding work.
    if query in doc_list:
        return query
    if not doc_list:
        # The original raised ValueError too ("max() arg is an empty
        # sequence"); this message is just clearer.
        raise ValueError("doc_list must not be empty")
    query_emb = get_embeddings(query, token_length, tokenizer, model)
    # Score every candidate against the query embedding.
    scored = (
        (cosine_similarity(
            get_embeddings(doc, token_length, tokenizer, model),
            query_emb)[0][0], doc)
        for doc in doc_list
    )
    # key= compares similarities only; the original bare max() over
    # (sim, doc) tuples accidentally tie-broke by comparing the
    # document strings lexicographically.
    return max(scored, key=lambda pair: pair[0])[1]
| # MAIN | |
def get_nearest_tags(user_tags: List[str]):
    """Map three user-supplied tags onto the nearest known training tags.

    Args:
        user_tags: Exactly three strings, in order: (genre, mood,
            instrument).

    Returns:
        Tuple ``(genre, mood, instrument)`` drawn from the vocabularies
        in ``./nlp/tags.json``, each the semantically closest match to
        the corresponding user tag.

    Raises:
        ValueError: If ``user_tags`` does not contain exactly three items.
        FileNotFoundError: If ``./nlp/tags.json`` is missing.
    """
    st.write("function called")  # debug trace shown in the Streamlit UI
    # NOTE(review): the tokenizer/model are re-loaded on every call —
    # consider caching them (e.g. st.cache_resource) to load once per session.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased",
                                      output_hidden_states=True)
    st.write("model downloaded")
    # Known tag vocabulary of the generative model's training space.
    with open("./nlp/tags.json", "r", encoding="utf-8") as jf:
        tags = json.load(jf)
    st.write("json opened")
    # Tuple unpacking enforces exactly three user tags
    # (raises ValueError otherwise).
    user_genre, user_mood, user_instr = user_tags
    genres, moods, instrs = tags["genre"], tags["mood"], tags["instrument"]
    st.write("waiting on return")
    return (
        nearest_doc(genres, user_genre, tokenizer, model),
        nearest_doc(moods, user_mood, tokenizer, model),
        nearest_doc(instrs, user_instr, tokenizer, model),
    )