Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import pandas as pd | |
| import thirdai | |
| from thirdai import bolt, licensing | |
| import os | |
| import time | |
# Activate the ThirdAI license before the model is loaded.
# NOTE(review): the license key is hard-coded in source; consider reading it
# from an environment variable or a secret store instead.
thirdai.licensing.activate("1FF7B0-458ABC-5F382D-0A1513-904CF0-V3")

# Number of result cards built in the UI and filled by every search.
max_posts = 5

# Recipe table and the pre-trained model used for retrieval.
df = pd.read_csv("processed_recipes_3.csv")
model = bolt.UniversalDeepTransformer.load("1bn_name_ctg_keywords_4gram.bolt")

# Map each value of the first CSV column to its row position (presumably the
# recipe id -- confirm against the CSV schema). Built in a single pass over
# the column instead of one positional ``iloc`` lookup per row; on duplicate
# values the last row still wins, as before.
recipe_id_to_row_num = {
    recipe_id: row_num for row_num, recipe_id in enumerate(df.iloc[:, 0])
}
# Captions of the per-result expand/collapse button.
SHOW_MORE = "Show more"
SHOW_LESS = "Show less"

# Captions of the per-result feedback button, before and after a click.
# NOTE(review): the leading "π" looks like a mis-decoded emoji; the intended
# character cannot be recovered from this file, so the text is kept as-is.
LIKE_TEXT = "π update LLM"
FEEDBACK_RECEIVED_TEXT = "π Click search for updated results"

# Markdown heading rendered at the top of the page.
INTRO_MARKDOWN = """# A billion parameter model, trained on a single CPU, in just 90 mins, on 522K recipes from food.com !!
"""
def retrain(query, doc_id, max_rounds=100):
    """Fine-tune the model until ``query`` is predicted as ``doc_id``.

    The query is lower-cased, newlines are replaced by spaces, and the text
    is split into overlapping 4-character grams joined by spaces. A one-row
    CSV (``Name`` = grams, ``RecipeId`` = ``str(doc_id)``) is written to a
    temporary file and the model is trained on it, one epoch at a time,
    until its predicted class equals ``doc_id``.

    Args:
        query: Raw query text the user searched with.
        doc_id: Id the model should return for this query.
        max_rounds: Upper bound on training epochs, so a model that never
            converges cannot block the caller forever.

    Returns:
        None. Nothing is trained when the query has fewer than four
        characters, because it yields no 4-grams.
    """
    # Normalise before tokenising; str.replace returns a new string, so the
    # result has to be assigned for the newline clean-up to take effect.
    query = query.lower().replace('\n', ' ')
    query = ' '.join(query[i:i + 4] for i in range(len(query) - 3))
    if not query:
        # No 4-grams: training would teach the model to map empty input
        # to doc_id.
        return

    sample = pd.DataFrame({
        "Name": [query],
        "RecipeId": [str(doc_id)],
    })
    filename = f"temptrain{hash(query)}{hash(doc_id)}{time.time()}.csv"
    try:
        sample.to_csv(filename)
        for _ in range(max_rounds):
            model.train(filename, epochs=1)
            prediction = model.predict(
                {"Name": query},
                return_predicted_class=True)
            if prediction == doc_id:
                break
    finally:
        # Remove the temporary file even when writing, training or
        # prediction raises.
        if os.path.exists(filename):
            os.remove(filename)
def search(query):
    """Run the model on ``query`` and build the UI updates for the results.

    The query is lower-cased, newlines are replaced by spaces (the query box
    is multi-line, and ``retrain`` is meant to apply the same clean-up), and
    the text is split into overlapping 4-character grams joined by spaces.
    Queries shorter than four characters therefore produce an empty token
    string.

    Args:
        query: Raw text from the query box.

    Returns:
        A flat list of component updates in this order: the results header,
        then one box, title, toggle button, feedback button and body per
        result, followed by the array of selected ids (best match first).
    """
    query = query.lower().replace('\n', ' ')
    query = ' '.join(query[i:i + 4] for i in range(len(query) - 3))
    scores = model.predict({"Name": query})

    # Positions of the max_posts highest scores, best first. They are used
    # directly as row positions in df.
    # NOTE(review): an earlier, disabled variant translated ids through
    # recipe_id_to_row_num instead -- confirm which mapping is correct.
    sorted_ids = scores.argsort()[-max_posts:][::-1]
    relevant_posts = [df.iloc[pid] for pid in sorted_ids]

    header = [gr.Markdown.update(visible=True)]
    boxes = [gr.Box.update(visible=True) for _ in relevant_posts]
    titles = [
        gr.Markdown.update(f"## {post['Name']}")
        for post in relevant_posts
    ]
    # Every card starts collapsed, with both buttons reset and clickable.
    toggles = [
        gr.Button.update(
            visible=True,
            value=SHOW_MORE,
            interactive=True,
        )
        for _ in relevant_posts
    ]
    matches = [
        gr.Button.update(
            value=LIKE_TEXT,
            interactive=True,
        )
        for _ in relevant_posts
    ]
    bodies = [
        gr.HTML.update(
            visible=False,
            value=f"<br/>"
            f"<h2>Description:</h2>\n{post['Description']}\n\n"
            "<hr class='solid'>"
            f"<h2>Ingredients:</h2>\n{post['RecipeIngredientParts']}\n\n"
            "<br/>"
            f"<h2>Instructions:</h2>\n{post['RecipeInstructions']}\n\n"
            "<br/>")
        for post in relevant_posts
    ]
    return (
        header +
        boxes +
        titles +
        toggles +
        matches +
        bodies +
        [sorted_ids]
    )
def handle_toggle(toggle):
    """Flip a result card between its collapsed and expanded state.

    Args:
        toggle: Current caption of the toggle button.

    Returns:
        Two updates: the button's new caption and the visibility of the
        card body. A caption of ``SHOW_MORE`` expands the card; any other
        caption collapses it, so an unexpected value can no longer leave
        the result variables unassigned.
    """
    expanding = toggle == SHOW_MORE
    return [
        gr.Button.update(SHOW_LESS if expanding else SHOW_MORE),
        gr.HTML.update(visible=expanding),
    ]
def handle_feedback(button_id: int):
    """Create the click handler for the feedback button of one result card.

    Args:
        button_id: Position of the card whose button the handler serves.

    Returns:
        A function taking ``(doc_ids, query)`` that retrains the model on
        the id at position ``button_id`` and returns an update that
        relabels and disables the button.
    """

    def register_feedback(doc_ids, query):
        # Pick the id that belongs to this card, then teach it to the model.
        chosen_doc = doc_ids[button_id]
        retrain(query=query, doc_id=chosen_doc)

        locked_button = gr.Button.update(
            value=FEEDBACK_RECEIVED_TEXT,
            interactive=False,
        )
        return locked_button

    return register_feedback
# Query pre-filled in the search box when the page loads.
default_query = (
    "biryani lamb spicy contains cloves and red chili powder, "
    "made with ghee and hard boiled eggs, "
    "made by grinding coconut and cashew"
)

with gr.Blocks() as demo:
    gr.Markdown(INTRO_MARKDOWN)
    query = gr.Textbox(value=default_query, label="Query", lines=10)
    submit = gr.Button(value="Search")

    # Results heading, hidden until a search has run.
    header = [gr.Markdown("# Relevant Recipes", visible=False)]

    # One entry per result card, collected so search can address them all.
    post_boxes, post_titles = [], []
    toggle_buttons, match_buttons = [], []
    post_bodies = []

    # Ids of the results currently displayed; read by the feedback handlers.
    post_ids = gr.State([])

    for i in range(max_posts):
        # Card layout: title on the left, the two buttons on the right, and
        # the collapsible body underneath.
        with gr.Box(visible=False) as box:
            with gr.Row():
                with gr.Column(scale=5):
                    title = gr.Markdown("")
                with gr.Column(scale=1, min_width=370):
                    with gr.Row():
                        with gr.Column(scale=3, min_width=170):
                            toggle = gr.Button(SHOW_MORE)
                        with gr.Column(scale=1, min_width=170):
                            match = gr.Button(LIKE_TEXT)
            body = gr.HTML("")

        # Feedback handler first, then the toggle handler, as before.
        match.click(
            fn=handle_feedback(button_id=i),
            inputs=[post_ids, query],
            outputs=[match],
        )
        toggle.click(
            fn=handle_toggle,
            inputs=[toggle],
            outputs=[toggle, body],
        )

        post_boxes.append(box)
        post_titles.append(title)
        toggle_buttons.append(toggle)
        match_buttons.append(match)
        post_bodies.append(body)

    # Must list components in the same order as the values search returns.
    allblocks = [
        *header,
        *post_boxes,
        *post_titles,
        *toggle_buttons,
        *match_buttons,
        *post_bodies,
        post_ids,
    ]

    # Pressing Enter in the query box and clicking "Search" do the same thing.
    for trigger in (query.submit, submit.click):
        trigger(
            fn=search,
            inputs=[query],
            outputs=allblocks)

demo.launch()