Hacker1337 committed on
Commit
7c398ad
·
1 Parent(s): 0c21c10

Debugged app locally and customized it.

Browse files
Files changed (2) hide show
  1. app.py +65 -46
  2. dataset.py +12 -6
app.py CHANGED
@@ -1,15 +1,13 @@
1
  import os
2
- from transformers import AutoModelForSequenceClassification
 
3
  import gradio as gr
4
- import numpy as np
5
- import random
6
 
7
  # import spaces #[uncomment to use ZeroGPU]
8
  # from diffusers import DiffusionPipeline
9
  import torch
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
- model_repo_id = "stabilityai/sdxl-turbo"
13
 
14
  if torch.cuda.is_available():
15
  torch_dtype = torch.float16
@@ -17,49 +15,61 @@ else:
17
  torch_dtype = torch.float32
18
 
19
 
20
- from dataset import labels, id2label, label2id, categorie2human
21
 
22
- # model_path = "distilbert/distilbert-base-cased" # todo, replace with hacker1337/article-classifier
23
- # model_path = os.path.expanduser(r"~\cache\huggingface\checkpoints\distilbert-arxiv\runs\Jun21_00-41-18_amir-xp")
24
- model_path = os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-arxiv")
25
 
26
- model = AutoModelForSequenceClassification.from_pretrained(
27
- model_path,
28
- num_labels=len(id2label),
29
- id2label=id2label,
30
- label2id=label2id,
31
- problem_type="multi_label_classification",
 
 
 
 
 
 
 
 
 
 
 
 
32
  )
33
 
34
 
35
  # @spaces.GPU #[uncomment to use ZeroGPU]
36
  def infer(
37
- prompt,
38
- negative_prompt,
39
- seed,
40
- randomize_seed,
41
- width,
42
- height,
43
- guidance_scale,
44
- num_inference_steps,
45
  progress=gr.Progress(track_tqdm=True),
46
  ):
47
- if randomize_seed:
48
- seed = random.randint(0, MAX_SEED)
49
-
50
- generator = torch.Generator().manual_seed(seed)
51
-
52
- # image = pipe(
53
- # prompt=prompt,
54
- # negative_prompt=negative_prompt,
55
- # guidance_scale=guidance_scale,
56
- # num_inference_steps=num_inference_steps,
57
- # width=width,
58
- # height=height,
59
- # generator=generator,
60
- # ).images[0]
61
-
62
- # return image, seed
 
 
 
 
 
 
 
 
63
 
64
 
65
  examples_titles = [
@@ -79,11 +89,6 @@ a comparison among the state of art networks both in terms of accuracy and in
79
  terms of speed which are of higher importance in real-time applications.""",
80
  ]
81
 
82
- examples = [
83
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
84
- "An astronaut riding a green horse",
85
- "A delicious ceviche cheesecake slice",
86
- ]
87
 
88
  css = """
89
  #col-container {
@@ -95,6 +100,13 @@ css = """
95
  with gr.Blocks(css=css) as demo:
96
  with gr.Column(elem_id="col-container"):
97
  gr.Markdown(" # Text-to-Image Gradio Template")
 
 
 
 
 
 
 
98
 
99
  title_prompt = gr.Text(
100
  label="Title Prompt",
@@ -111,10 +123,17 @@ with gr.Blocks(css=css) as demo:
111
  container=False,
112
  )
113
 
114
-
115
  run_button = gr.Button("Run", scale=0, variant="primary")
116
 
117
- result = gr.Image(label="Result", show_label=False)
 
 
 
 
 
 
 
 
118
 
119
  # with gr.Accordion("Advanced Settings", open=False):
120
 
@@ -127,7 +146,7 @@ with gr.Blocks(css=css) as demo:
127
  title_prompt,
128
  summary_prompt,
129
  ],
130
- outputs=[result],
131
  )
132
 
133
  if __name__ == "__main__":
 
1
  import os
2
+ import socket
3
+ import pandas as pd
4
  import gradio as gr
 
 
5
 
6
  # import spaces #[uncomment to use ZeroGPU]
7
  # from diffusers import DiffusionPipeline
8
  import torch
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
11
 
12
  if torch.cuda.is_available():
13
  torch_dtype = torch.float16
 
15
  torch_dtype = torch.float32
16
 
17
 
18
+ from dataset import category2human, create_prompt
19
 
20
+ LOCAL_COMPUTER_NAMES = ["amir-xps"]
 
 
21
 
22
+
23
def is_local_machine():
    """Tell whether the current host is listed in LOCAL_COMPUTER_NAMES.

    The comparison is case-insensitive: both the host name reported by
    ``socket.gethostname()`` and the configured names are lower-cased first.

    Returns:
        True when the lower-cased host name matches one of the configured
        names, False otherwise.
    """
    known_hosts = {candidate.lower() for candidate in LOCAL_COMPUTER_NAMES}
    current_host = socket.gethostname().lower()
    return current_host in known_hosts
27
+
28
+
29
# On a machine listed in LOCAL_COMPUTER_NAMES, load the checkpoint from the
# local Hugging Face cache directory; anywhere else, load it by its Hub
# repository id.
if is_local_machine():
    model_path = os.path.expanduser("~/.cache/huggingface/checkpoints/distilbert-arxiv")
else:
    model_path = "Hacker1337/distilbert-arxiv-checkpoint"

from transformers import pipeline

# `device` is computed near the top of this file ("cuda" when available, else
# "cpu"). Pass it explicitly so that choice actually takes effect; previously
# it was computed but never handed to the pipeline.
classifier = pipeline(
    "text-classification",
    model=model_path,
    tokenizer=model_path,
    device=device,
)
41
 
42
 
43
  # @spaces.GPU #[uncomment to use ZeroGPU]
44
def infer(
    title_prompt,
    summary_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Classify a paper from its title and abstract.

    Args:
        title_prompt: Paper title entered by the user.
        summary_prompt: Paper abstract entered by the user.
        progress: Gradio progress tracker; it is not used inside the body.

    Returns:
        A ``(df, label_dict)`` pair.

        ``df`` is a DataFrame built from the classifier predictions
        (``label`` and ``score`` entries), with every label replaced by its
        human-readable name from ``category2human``.

        ``label_dict`` maps human-readable category names to their share of
        the total score. Categories are taken from the highest score down
        while the cumulative share stays below 95%; the top category is
        always kept. Whatever share is left over is reported under
        ``"Other"``.
    """
    target_probs_sum = 0.95

    sample_prompt_full = create_prompt(title_prompt, summary_prompt)
    predictions = classifier(sample_prompt_full, top_k=None)
    print(predictions)

    df = pd.DataFrame(predictions)
    df["label"] = df["label"].apply(lambda x: category2human[x])

    # Scores are renormalised by their sum so the shares add up to 1.
    total_prob = sum(prediction["score"] for prediction in predictions)

    # Rank explicitly so the cumulative cut-off below does not depend on the
    # order in which the classifier happens to return its predictions.
    ranked = sorted(predictions, key=lambda p: p["score"], reverse=True)

    label_dict = {}
    gained_prob = 0
    for prediction in ranked:
        score = prediction["score"]
        crosses_target = (gained_prob + score) / total_prob >= target_probs_sum
        # Always keep the top category (otherwise a very confident prediction
        # would be reported only as "Other"), and stop at the first category
        # that would cross the threshold instead of skipping it and then
        # listing less likely ones after it.
        if label_dict and crosses_target:
            break
        label_dict[category2human[prediction["label"]]] = score / total_prob
        gained_prob += score

    if gained_prob < total_prob:
        label_dict["Other"] = (total_prob - gained_prob) / total_prob
    return df, label_dict
73
 
74
 
75
  examples_titles = [
 
89
  terms of speed which are of higher importance in real-time applications.""",
90
  ]
91
 
 
 
 
 
 
92
 
93
  css = """
94
  #col-container {
 
100
  with gr.Blocks(css=css) as demo:
101
  with gr.Column(elem_id="col-container"):
102
  gr.Markdown(" # Text-to-Image Gradio Template")
103
+ gr.Markdown(
104
+ "This space classifies scientific machine learning papers into categories based on their title and abstract."
105
+ "Bar plot shows probabilities of belonging to each category."
106
+ )
107
+ gr.Markdown(
108
+ "Second thing predicts most probable single class classification. It shows only first 95\% of categories."
109
+ )
110
 
111
  title_prompt = gr.Text(
112
  label="Title Prompt",
 
123
  container=False,
124
  )
125
 
 
126
  run_button = gr.Button("Run", scale=0, variant="primary")
127
 
128
+ result_bar = gr.BarPlot(
129
+ label="Multi class classification",
130
+ show_label=True,
131
+ x="label",
132
+ y="score",
133
+ x_label_angle=30,
134
+ )
135
+
136
+ result_label = gr.Label(label="Single class selection")
137
 
138
  # with gr.Accordion("Advanced Settings", open=False):
139
 
 
146
  title_prompt,
147
  summary_prompt,
148
  ],
149
+ outputs=[result_bar, result_label],
150
  )
151
 
152
  if __name__ == "__main__":
dataset.py CHANGED
@@ -1,13 +1,14 @@
1
  labels = ["CV", "AI", "ML", "NE", "CL"]
 
2
  id2label = {i: label for i, label in enumerate(labels)}
3
  label2id = {label: i for i, label in enumerate(labels)}
4
 
5
- categorie2human = {
6
  "CV": "Computer Vision",
7
  "AI": "Artificial Intelligence",
8
  "ML": "Machine Learning",
9
  "NE": "Neural and Evolutionary Computing",
10
- "CL": "Computation and Language"
11
  }
12
 
13
 
@@ -19,13 +20,11 @@ def load_arxiv_dataset():
19
  # Download latest version
20
  path = kagglehub.dataset_download("spsayakpaul/arxiv-paper-abstracts")
21
 
22
-
23
-
24
  dataset = load_dataset(
25
  "csv",
26
  data_files=os.path.join(path, "arxiv_data.csv"),
27
  encoding="utf-8",
28
- split="train"
29
  )
30
 
31
  # convert string to lists
@@ -37,4 +36,11 @@ def load_arxiv_dataset():
37
 
38
  dataset = dataset.map(parse_terms)
39
 
40
- return dataset
 
 
 
 
 
 
 
 
1
  labels = ["CV", "AI", "ML", "NE", "CL"]
2
+
3
  id2label = {i: label for i, label in enumerate(labels)}
4
  label2id = {label: i for i, label in enumerate(labels)}
5
 
6
+ category2human = {
7
  "CV": "Computer Vision",
8
  "AI": "Artificial Intelligence",
9
  "ML": "Machine Learning",
10
  "NE": "Neural and Evolutionary Computing",
11
+ "CL": "Computation and Language",
12
  }
13
 
14
 
 
20
  # Download latest version
21
  path = kagglehub.dataset_download("spsayakpaul/arxiv-paper-abstracts")
22
 
 
 
23
  dataset = load_dataset(
24
  "csv",
25
  data_files=os.path.join(path, "arxiv_data.csv"),
26
  encoding="utf-8",
27
+ split="train",
28
  )
29
 
30
  # convert string to lists
 
36
 
37
  dataset = dataset.map(parse_terms)
38
 
39
+ return dataset
40
+
41
+
42
def create_prompt(title, summary):
    """Build the model input text from a paper's title and abstract.

    Args:
        title: Paper title.
        summary: Paper abstract.

    Returns:
        A single string with the title under a ``# title:`` header line and
        the abstract under a ``# abstract:`` header line.
    """
    template = "# title:\n{}\n# abstract:\n{}"
    return template.format(title, summary)