# Hugging Face Space: movie-plot similarity search.
# (Space status at time of capture: "Runtime error".)
| import gradio as gr | |
| import pickle | |
| import os | |
| from datasets import load_dataset | |
| from gradio.components import Label | |
| from InstructorEmbedding import INSTRUCTOR | |
| import heapq | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import nltk | |
| from nltk.corpus import stopwords | |
| from nltk.tokenize import word_tokenize, sent_tokenize | |
| from nltk.stem import WordNetLemmatizer | |
| import pandas as pd | |
# One-time startup cost: pull the movie corpus and the Instructor embedding model.
dataset = load_dataset("SandipPalit/Movie_Dataset")
model = INSTRUCTOR('hkunlp/instructor-xl')
def getSimilarity(sentences_a, sentences_b):
    """Return the cosine-similarity matrix between cached corpus embeddings
    and freshly encoded query sentences.

    Corpus embeddings are read from a pre-computed pickle (``temp.pkl`` in
    the current working directory); only ``sentences_b`` is encoded here.

    NOTE(review): ``sentences_a`` is accepted but never used — the pickle is
    assumed to hold embeddings for exactly those sentences. Verify the cache
    actually matches the corpus passed in.
    """
    cache_path = os.getcwd() + "/temp.pkl"
    # Fix: the original `pickle.load(open(...))` leaked the file handle;
    # a context manager guarantees it is closed.
    with open(cache_path, 'rb') as cache_file:
        embeddings_a = pickle.load(cache_file)
    embeddings_b = model.encode(sentences_b)
    similarities = cosine_similarity(embeddings_a, embeddings_b)
    return similarities
# Fetch the NLTK data needed by sent_tokenize, the stopword list and the lemmatizer.
for _resource in ('punkt', 'stopwords', 'wordnet'):
    nltk.download(_resource)
def preprocess(idx, text, total_length):
    """Split ``text`` into sentences, drop stopwords, lemmatize the rest, and
    tag every sentence with a fixed-width record id so it can be traced back
    to its movie later.

    The tag has the form ``@<zero-padded idx>`` where the width is the digit
    count of ``total_length``, so all ids occupy the same number of characters.
    """
    stop_words = set(stopwords.words('english'))
    lemmatizer = WordNetLemmatizer()
    # e.g. idx=7, total_length=1000 -> '@0007' (same result as the manual
    # '0' * (width - len) padding).
    tag = '@' + str(idx).zfill(len(str(total_length)))
    cleaned = []
    for sentence in sent_tokenize(text):
        kept = (lemmatizer.lemmatize(word)
                for word in sentence.split()
                if word not in stop_words)
        cleaned.append(' '.join(kept) + tag)
    return cleaned
def get_pre_processed_data(size):
    """Preprocess the plots of the first ``size`` movies into one flat list
    of tagged sentences (see ``preprocess`` for the tag format)."""
    collected = []
    plots = df['Plot'].head(size).tolist()
    for movie_id, plot in enumerate(plots):
        collected += preprocess(movie_id, plot, df.shape[0])
    return collected
# building the max-heap of (negated score, sentence index) pairs
def heapsort(np_array, k):
    """Build a min-heap of ``(-score, index)`` pairs, i.e. a max-heap on score.

    Popping the returned heap yields indices in descending score order, ties
    broken by the lower index. ``k`` is unused but kept for interface
    compatibility with existing callers.
    """
    # Fix: one O(n) heapify instead of n individual pushes (O(n log n)).
    h = [(-score, idx) for idx, score in enumerate(np_array)]
    heapq.heapify(h)
    return h
# return the ids of the movies whose sentences match best
def get_top_k_matches(np_array, k, sentences):
    """Return up to ``k`` distinct movie ids, best similarity score first.

    ``np_array`` holds one score per entry of ``sentences``; each sentence
    ends with an ``'@'``-separated, zero-padded movie id (see ``preprocess``).
    Duplicate ids (several sentences of the same movie) are emitted only once.
    """
    indices = []    # ids in output order (fix: original had a dead `indices=set()` first)
    visited = set()  # ids already emitted, for O(1) duplicate checks
    # Descending score, ties broken by lower sentence index — exactly the
    # order produced by popping a heap of (-score, idx) pairs.
    for _, idx in sorted((-score, idx) for idx, score in enumerate(np_array)):
        if len(indices) == k:
            break
        # Fix: the movie id is everything after the last '@'; int() handles
        # the zero padding, replacing the manual backward digit scan.
        number = int(sentences[idx].rsplit('@', 1)[1])
        if number not in visited:
            indices.append(number)
            visited.add(number)
    return indices
| df=pd.DataFrame({"Title":dataset['train']['Title'],"Plot":dataset['train']['Overview']}) | |
def getOutput(text, size=1000):
    """Return up to five 'title = ... Plot = ...' strings describing the
    movies whose plots are most similar to ``text``, searching the first
    ``size`` movies of the corpus."""
    corpus = get_pre_processed_data(int(size))
    scores = getSimilarity(corpus, [text])
    results = []
    for movie_id in get_top_k_matches(scores, 5, corpus):
        row = df.iloc[movie_id]
        results.append("title = " + row['Title'] + " " * 5 + " Plot = " + row['Plot'])
    return results
# Wire the search function into a Gradio UI: one text box in, five labels out.
# Fix: `gr.inputs.Textbox` was deprecated in Gradio 3.x and removed in 4.x —
# the likely cause of this Space's "Runtime error". `gr.Textbox` is the
# supported component path.
iface = gr.Interface(
    fn=getOutput,
    inputs=[gr.Textbox(label="Text")],
    outputs=[Label() for i in range(5)],
    examples=[
        ['After doing the list of experiments A mad scientist declares himself as the god'],
        ["Three men fight for the girl's love"],
    ],
)
iface.launch(debug=True)