Spaces:
Sleeping
Sleeping
| import os | |
| import socket | |
| import pandas as pd | |
| import gradio as gr | |
| # import spaces #[uncomment to use ZeroGPU] | |
| # from diffusers import DiffusionPipeline | |
| import torch | |
# Select the compute device and a matching floating-point precision:
# half precision on CUDA, full precision on CPU.
# NOTE(review): neither `device` nor `torch_dtype` is referenced anywhere else
# in this file; the pipeline below is created without them.
device, torch_dtype = (
    ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)
| from dataset import category2human, create_prompt | |
# Hostnames of developer machines; on these the model is loaded from a local
# checkpoint directory instead of the Hub repo id (see `model_path` below).
LOCAL_COMPUTER_NAMES = ["amir-xps"]


def is_local_machine():
    """Return True if the current hostname is listed in LOCAL_COMPUTER_NAMES.

    Both the hostname and the listed names are lower-cased before comparing,
    so the match is case-insensitive.
    """
    current_host = socket.gethostname().lower()
    return any(current_host == known_host.lower() for known_host in LOCAL_COMPUTER_NAMES)
# On a known developer machine use the locally cached checkpoint directory;
# everywhere else pass a Hugging Face Hub repo id to the pipeline.
model_path = (
    os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-arxiv2")
    if is_local_machine()
    else "Hacker1337/distilbert-arxiv-checkpoint"
)
| from transformers import pipeline | |
# Text-classification pipeline; model weights and tokenizer are both loaded
# from `model_path` (local checkpoint or Hub repo id).
classifier = pipeline(
    task="text-classification",
    tokenizer=model_path,
    model=model_path,
)
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    title_prompt,
    summary_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Classify a paper from its title and abstract.

    Args:
        title_prompt: Paper title entered by the user.
        summary_prompt: Paper abstract entered by the user.
        progress: Gradio progress tracker (``track_tqdm=True``); not used
            directly in the body.

    Returns:
        A tuple ``(df, label_dict)``:

        * ``df`` -- DataFrame with one row per category and columns
          ``label`` (mapped through ``category2human``, presumably a
          human-readable name) and ``score`` (raw classifier score). It feeds
          the bar plot.
        * ``label_dict`` -- the highest-scoring categories, normalized by the
          total score, taken in descending order until their cumulative share
          reaches ``target_probs_sum`` (0.95). If a non-negligible share
          remains, it is reported under ``"Other"``.
    """
    sample_prompt_full = create_prompt(title_prompt, summary_prompt)
    # top_k=None returns a score for every category, not just the best one.
    # truncation=True keeps over-long user input within the model's maximum
    # sequence length instead of failing at inference time.
    predictions = classifier(sample_prompt_full, top_k=None, truncation=True)
    target_probs_sum = 0.95
    print(predictions)

    df = pd.DataFrame(predictions)
    df["label"] = df["label"].apply(lambda x: category2human[x])

    # Normalize by the total in case the scores do not sum exactly to 1.
    total_prob = sum(prediction["score"] for prediction in predictions)
    label_dict = {}
    gained_prob = 0
    for prediction in sorted(predictions, key=lambda x: x["score"], reverse=True):
        # Scores are visited in descending order and gained_prob only grows,
        # so once the target share is reached no later category qualifies.
        if gained_prob / total_prob >= target_probs_sum:
            break
        label_dict[category2human[prediction["label"]]] = (
            prediction["score"] / total_prob
        )
        gained_prob += prediction["score"]

    # Report the omitted categories' share only when something is actually
    # left over; a tolerance avoids zero or float-noise "Other" entries.
    leftover = total_prob - gained_prob
    if leftover > 1e-5:
        label_dict["Other"] = leftover / total_prob
    return df, label_dict
# Sample title shown under the title box via gr.Examples.
examples_titles = [
    "Survey on Semantic Stereo Matching",
]
# Sample abstract shown under the abstract box via gr.Examples; it belongs to
# the sample title above.
examples_summaries = [
    """Stereo matching is one of the widely used techniques for inferring depth from
stereo images owing to its robustness and speed. It has become one of the major
topics of research since it finds its applications in autonomous driving,
robotic navigation, 3D reconstruction, and many other fields. Finding pixel
correspondences in non-textured, occluded and reflective areas is the major
challenge in stereo matching. Recent developments have shown that semantic cues
from image segmentation can be used to improve the results of stereo matching.
Many deep neural network architectures have been proposed to leverage the
advantages of semantic segmentation in stereo matching. This paper aims to give
a comparison among the state of art networks both in terms of accuracy and in
terms of speed which are of higher importance in real-time applications.""",
]
# Custom CSS passed to gr.Blocks: centers the column with
# elem_id="col-container" and caps its width at 640px.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Page layout: title and abstract inputs, a Run button, and the two outputs
# produced by `infer` -- a bar plot over all categories and a Label listing
# the most probable ones.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# arXiv Paper Category Classifier")
        gr.Markdown(
            "This space classifies scientific machine learning papers into "
            "categories based on their title and abstract. "
            "The bar plot shows the probability of belonging to each category."
        )
        gr.Markdown(
            "The label below shows the most probable categories, listed until "
            "they cover 95% of the probability mass; the remaining probability "
            'is grouped under "Other".'
        )
        title_prompt = gr.Text(
            label="Title Prompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter paper's title",
            container=False,
        )
        summary_prompt = gr.Text(
            label="Summary Prompt",
            show_label=False,
            max_lines=10,
            placeholder="Enter paper's abstract",
            container=False,
        )
        run_button = gr.Button("Run", scale=0, variant="primary")
        # Column names must match the DataFrame returned by `infer`.
        result_bar = gr.BarPlot(
            label="Multi class classification",
            show_label=True,
            x="label",
            y="score",
            x_label_angle=30,
        )
        result_label = gr.Label(label="Single class selection")

        gr.Examples(examples=examples_titles, inputs=[title_prompt])
        gr.Examples(examples=examples_summaries, inputs=[summary_prompt])

    # Run inference on a button click or on a submit event from either box.
    gr.on(
        triggers=[run_button.click, title_prompt.submit, summary_prompt.submit],
        fn=infer,
        inputs=[
            title_prompt,
            summary_prompt,
        ],
        outputs=[result_bar, result_label],
    )

if __name__ == "__main__":
    demo.launch()