ssocean committed on
Commit
984938f
·
verified ·
1 Parent(s): 757c76c

update NAIPv2

Browse files
Files changed (1) hide show
  1. app.py +109 -49
app.py CHANGED
@@ -2,66 +2,99 @@ import gradio as gr
2
  import spaces
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
5
  import torch.nn.functional as F
6
- import torch.nn as nn
7
  import re
8
- model_path = r'ssocean/NAIP'
9
  device = 'cuda:0'
10
 
11
- global model, tokenizer
12
- model = None
13
- tokenizer = None
14
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- @spaces.GPU(duration=60, enable_queue=True)
17
- def predict(title, abstract):
18
- title = title.replace("\n", " ").strip().replace('’',"'")
19
- abstract = abstract.replace("\n", " ").strip().replace('’',"'")
20
- global model, tokenizer
21
- if model is None:
22
- model = AutoModelForSequenceClassification.from_pretrained(
23
- model_path,
24
- num_labels=1,
25
- load_in_8bit=True,)
26
- tokenizer = AutoTokenizer.from_pretrained(model_path)
27
- model.eval()
28
- print(title + '\n' + abstract)
29
- text = f'''Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):'''
30
- inputs = tokenizer(text, return_tensors="pt").to(device)
31
  with torch.no_grad():
32
- outputs = model(**inputs)
33
- probability = torch.sigmoid(outputs.logits).item()
34
- # reason for +0.05: We observed that the predicted values in the web demo are generally around 0.05 lower than those in the local deployment (due to differences in software/hardware environments, we believed). Therefore, we applied the following compensation in the web demo. Please do not use this in the local deployment.
35
- if probability + 0.05 >=1.0:
36
  return round(1, 4)
37
- return round(probability + 0.05, 4)
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- examples = [
42
- [
43
- "SARDet-100K: Towards Open-Source Benchmark and ToolKit for Large-Scale SAR Object Detection",
44
- ('''Synthetic Aperture Radar (SAR) object detection has gained significant attention recently due to its irreplaceable all-weather imaging capabilities. However, this research field suffers from both limited public datasets (mostly comprising <2K images with only mono-category objects) and inaccessible source code. To tackle these challenges, we establish a new benchmark dataset and an open-source method for large-scale SAR object detection. Our dataset, SARDet-100K, is a result of intense surveying, collecting, and standardizing 10 existing SAR detection datasets, providing a large-scale and diverse dataset for research purposes. To the best of our knowledge, SARDet-100K is the first COCO-level large-scale multi-class SAR object detection dataset ever created. With this high-quality dataset, we conducted comprehensive experiments and uncovered a crucial challenge in SAR object detection: the substantial disparities between the pretraining on RGB datasets and finetuning on SAR datasets in terms of both data domain and model structure. To bridge these gaps, we propose a novel Multi-Stage with Filter Augmentation (MSFA) pretraining framework that tackles the problems from the perspective of data input, domain transition, and model migration. The proposed MSFA method significantly enhances the performance of SAR object detection models while demonstrating exceptional generalizability and flexibility across diverse models. This work aims to pave the way for further advancements in SAR object detection. The dataset and code is available at this https URL.''')
45
- ],
46
- [
47
- "OminiControl: Minimal and Universal Control for Diffusion Transformer",
48
- ('''In this paper, we introduce OminiControl, a highly versatile and parameter-efficient framework that integrates image conditions into pre-trained Diffusion Transformer (DiT) models. At its core, OminiControl leverages a parameter reuse mechanism, enabling the DiT to encode image conditions using itself as a powerful backbone and process them with its flexible multi-modal attention processors. Unlike existing methods, which rely heavily on additional encoder modules with complex architectures, OminiControl (1) effectively and efficiently incorporates injected image conditions with only ~0.1% additional parameters, and (2) addresses a wide range of image conditioning tasks in a unified manner, including subject-driven generation and spatially-aligned conditions such as edges, depth, and more. Remarkably, these capabilities are achieved by training on images generated by the DiT itself, which is particularly beneficial for subject-driven generation. Extensive evaluations demonstrate that OminiControl outperforms existing UNet-based and DiT-adapted models in both subject-driven and spatially-aligned conditional generation. Additionally, we release our training dataset, Subjects200K, a diverse collection of over 200,000 identity-consistent images, along with an efficient data synthesis pipeline to advance research in subject-consistent generation.''')
49
- ],
50
- [
51
- "Enhanced ZSSR for Super-resolution Reconstruction of the Historical Tibetan Document Images",
52
- "Due to the poor preservation and imaging conditions, the image quality of historical Tibetan document images is relatively unsatisfactory. In this paper, we adopt super-resolution technology to reconstruct high quality images of historical Tibetan document. To address the problem of low quantity and poor quality of historical Tibetan document images, we propose the EZSSR network based on the Zero-Shot Super-resolution Network (ZSSR), which borrows the idea of feature pyramid in Deep Laplacian Pyramid Networks (LapSRN) to extract different levels of features while alleviating the ringing artifacts. EZSSR neither requires paired training datasets nor preprocessing stage. The computational complexity of EZSSR is low, and thus, EZSSR can also reconstruct image within the acceptable time frame. Experimental results show that EZSSR reconstructs images with better visual effects and higher PSNR and SSIM values."
53
- ]
54
-
55
- ]
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def validate_input(title, abstract):
58
- title = title.replace("\n", " ").strip().replace('’',"'")
59
- abstract = abstract.replace("\n", " ").strip().replace('’',"'")
60
 
61
  non_latin_pattern = re.compile(r'[^\u0000-\u007F]')
62
  non_latin_in_title = non_latin_pattern.findall(title)
63
  non_latin_in_abstract = non_latin_pattern.findall(abstract)
64
-
65
  if len(title.strip().split(' ')) < 3:
66
  return False, "The title must be at least 3 words long."
67
  if len(abstract.strip().split(' ')) < 50:
@@ -72,16 +105,35 @@ def validate_input(title, abstract):
72
  return False, f"The title contains invalid characters: {', '.join(non_latin_in_title)}. Only English letters and special symbols are allowed."
73
  if non_latin_in_abstract:
74
  return False, f"The abstract contains invalid characters: {', '.join(non_latin_in_abstract)}. Only English letters and special symbols are allowed."
75
-
76
  return True, "Inputs are valid! Good to go!"
77
 
78
- def update_button_status(title, abstract):
79
 
 
80
  valid, message = validate_input(title, abstract)
81
  if not valid:
82
  return gr.update(value="Error: " + message), gr.update(interactive=False)
83
  return gr.update(value=message), gr.update(interactive=True)
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  with gr.Blocks() as iface:
86
  gr.Markdown("""
87
  # 📈 Predict Academic Impact of Newly Published Paper!
@@ -89,20 +141,27 @@ with gr.Blocks() as iface:
89
  ###### [Full Paper](https://arxiv.org/abs/2408.03934)
90
  ###### Please be advised: Local inference of the proposed method is instant, but ZeroGPU requires quantized model reinitialization with each "Predict", causing slight delays. (typically wont take more than 30 secs)
91
  """)
 
92
  with gr.Row():
93
  with gr.Column():
 
 
 
 
 
94
  title_input = gr.Textbox(
95
  lines=2,
96
- placeholder='''Enter Paper Title Here... (Title will be processed with 'title.replace("\\n", " ").strip()')''',
97
  label="Paper Title"
98
  )
99
  abstract_input = gr.Textbox(
100
  lines=5,
101
- placeholder='''Enter Paper Abstract Here... (Abstract will be processed with 'abstract.replace("\\n", " ").strip()')''',
102
  label="Paper Abstract"
103
  )
104
  validation_status = gr.Textbox(label="Validation Status", interactive=False)
105
  submit_button = gr.Button("Predict Impact", interactive=False)
 
106
  with gr.Column():
107
  output = gr.Label(label="Predicted Impact")
108
  gr.Markdown("""
@@ -116,7 +175,6 @@ with gr.Blocks() as iface:
116
  - The **author takes NO responsibility** for the prediction results.
117
  """)
118
 
119
-
120
  title_input.change(
121
  update_button_status,
122
  inputs=[title_input, abstract_input],
@@ -130,7 +188,7 @@ with gr.Blocks() as iface:
130
 
131
  submit_button.click(
132
  predict,
133
- inputs=[title_input, abstract_input],
134
  outputs=output
135
  )
136
 
@@ -140,4 +198,6 @@ with gr.Blocks() as iface:
140
  outputs=[validation_status, output],
141
  cache_examples=False
142
  )
 
143
  iface.launch()
 
 
2
  import spaces
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ from peft import AutoPeftModelForSequenceClassification
6
  import torch.nn.functional as F
 
7
  import re
8
+
9
  device = 'cuda:0'
10
 
11
+ # ===== v1 部分 =====
12
+ model_v1, tokenizer_v1 = None, None
13
+ model_path_v1 = r'ssocean/NAIP'
14
 
15
+ def predict_v1(title, abstract):
16
+ global model_v1, tokenizer_v1
17
+ if model_v1 is None:
18
+ model_v1 = AutoModelForSequenceClassification.from_pretrained(
19
+ model_path_v1,
20
+ num_labels=1,
21
+ load_in_8bit=True,
22
+ ).to(device)
23
+ tokenizer_v1 = AutoTokenizer.from_pretrained(model_path_v1)
24
+ model_v1.eval()
25
 
26
+ text = f"Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):"
27
+ inputs = tokenizer_v1(text, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  with torch.no_grad():
29
+ outputs = model_v1(**inputs)
30
+ prob = torch.sigmoid(outputs.logits).item()
31
+ if prob + 0.05 >= 1.0:
 
32
  return round(1, 4)
33
+ return round(prob + 0.05, 4)
34
 
35
 
36
+ # ===== v2 部分 =====
37
+ scorer_v2 = None
38
+ model_path_v2 = r'ssocean/NAIPv2'
39
+
40
+ class PaperScorer:
41
+ def __init__(self, model_path: str, device: str = 'cuda', max_length: int = 512):
42
+ self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
43
+ self.max_length = max_length
44
+
45
+ # PEFT 模型 (LoRA)
46
+ self.model = AutoPeftModelForSequenceClassification.from_pretrained(
47
+ model_path,
48
+ device_map="auto",
49
+ num_labels=1
50
+ ).to(self.device).eval()
51
+
52
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
53
+ self.tokenizer.pad_token = self.tokenizer.eos_token
54
+ self.model.config.pad_token_id = self.tokenizer.pad_token_id
55
+
56
+ self.prompt_template = (
57
+ "Given a certain paper, Title: {title}\nAbstract: {abstract}\nEvaluate the quality of this paper:"
58
+ )
59
+
60
+ def score(self, title: str, abstract: str) -> float:
61
+ prompt = self.prompt_template.format(title=title.strip(), abstract=abstract.strip())
62
+ inputs = self.tokenizer(prompt, return_tensors='pt', padding=True, truncation=True,
63
+ max_length=self.max_length).to(self.device)
64
+ with torch.no_grad():
65
+ logits = self.model(**inputs).logits + 1.3
66
+ score = torch.sigmoid(logits).view(-1).item()
67
+ return round(score, 4)
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ def predict_v2(title, abstract):
71
+ global scorer_v2
72
+ if scorer_v2 is None:
73
+ scorer_v2 = PaperScorer(model_path_v2, device=device)
74
+ return scorer_v2.score(title, abstract)
75
+
76
+
77
+ # ===== 统一接口 =====
78
+ @spaces.GPU(duration=60, enable_queue=True)
79
+ def predict(title, abstract, model_version):
80
+ title = title.replace("\n", " ").strip().replace('’', "'")
81
+ abstract = abstract.replace("\n", " ").strip().replace('’', "'")
82
+
83
+ if model_version == "v1":
84
+ return predict_v1(title, abstract)
85
+ else:
86
+ return predict_v2(title, abstract)
87
+
88
+
89
+ # ===== 输入校验 =====
90
  def validate_input(title, abstract):
91
+ title = title.replace("\n", " ").strip().replace('’', "'")
92
+ abstract = abstract.replace("\n", " ").strip().replace('’', "'")
93
 
94
  non_latin_pattern = re.compile(r'[^\u0000-\u007F]')
95
  non_latin_in_title = non_latin_pattern.findall(title)
96
  non_latin_in_abstract = non_latin_pattern.findall(abstract)
97
+
98
  if len(title.strip().split(' ')) < 3:
99
  return False, "The title must be at least 3 words long."
100
  if len(abstract.strip().split(' ')) < 50:
 
105
  return False, f"The title contains invalid characters: {', '.join(non_latin_in_title)}. Only English letters and special symbols are allowed."
106
  if non_latin_in_abstract:
107
  return False, f"The abstract contains invalid characters: {', '.join(non_latin_in_abstract)}. Only English letters and special symbols are allowed."
108
+
109
  return True, "Inputs are valid! Good to go!"
110
 
 
111
 
112
+ def update_button_status(title, abstract):
113
  valid, message = validate_input(title, abstract)
114
  if not valid:
115
  return gr.update(value="Error: " + message), gr.update(interactive=False)
116
  return gr.update(value=message), gr.update(interactive=True)
117
 
118
+
119
+ # ===== 示例数据 =====
120
+ examples = [
121
+ [
122
+ "SARDet-100K: Towards Open-Source Benchmark and ToolKit for Large-Scale SAR Object Detection",
123
+ ('''Synthetic Aperture Radar (SAR) object detection has gained significant attention recently due to its irreplaceable all-weather imaging capabilities. However, this research field suffers from both limited public datasets (mostly comprising <2K images with only mono-category objects) and inaccessible source code. To tackle these challenges, we establish a new benchmark dataset and an open-source method for large-scale SAR object detection. Our dataset, SARDet-100K, is a result of intense surveying, collecting, and standardizing 10 existing SAR detection datasets, providing a large-scale and diverse dataset for research purposes. To the best of our knowledge, SARDet-100K is the first COCO-level large-scale multi-class SAR object detection dataset ever created. With this high-quality dataset, we conducted comprehensive experiments and uncovered a crucial challenge in SAR object detection: the substantial disparities between the pretraining on RGB datasets and finetuning on SAR datasets in terms of both data domain and model structure. To bridge these gaps, we propose a novel Multi-Stage with Filter Augmentation (MSFA) pretraining framework that tackles the problems from the perspective of data input, domain transition, and model migration. The proposed MSFA method significantly enhances the performance of SAR object detection models while demonstrating exceptional generalizability and flexibility across diverse models. This work aims to pave the way for further advancements in SAR object detection. The dataset and code is available at this https URL.''')
124
+ ],
125
+ [
126
+ "OminiControl: Minimal and Universal Control for Diffusion Transformer",
127
+ ('''In this paper, we introduce OminiControl, a highly versatile and parameter-efficient framework that integrates image conditions into pre-trained Diffusion Transformer (DiT) models. At its core, OminiControl leverages a parameter reuse mechanism, enabling the DiT to encode image conditions using itself as a powerful backbone and process them with its flexible multi-modal attention processors. Unlike existing methods, which rely heavily on additional encoder modules with complex architectures, OminiControl (1) effectively and efficiently incorporates injected image conditions with only ~0.1% additional parameters, and (2) addresses a wide range of image conditioning tasks in a unified manner, including subject-driven generation and spatially-aligned conditions such as edges, depth, and more. Remarkably, these capabilities are achieved by training on images generated by the DiT itself, which is particularly beneficial for subject-driven generation. Extensive evaluations demonstrate that OminiControl outperforms existing UNet-based and DiT-adapted models in both subject-driven and spatially-aligned conditional generation. Additionally, we release our training dataset, Subjects200K, a diverse collection of over 200,000 identity-consistent images, along with an efficient data synthesis pipeline to advance research in subject-consistent generation.''')
128
+ ],
129
+ [
130
+ "Enhanced ZSSR for Super-resolution Reconstruction of the Historical Tibetan Document Images",
131
+ "Due to the poor preservation and imaging conditions, the image quality of historical Tibetan document images is relatively unsatisfactory. In this paper, we adopt super-resolution technology to reconstruct high quality images of historical Tibetan document. To address the problem of low quantity and poor quality of historical Tibetan document images, we propose the EZSSR network based on the Zero-Shot Super-resolution Network (ZSSR), which borrows the idea of feature pyramid in Deep Laplacian Pyramid Networks (LapSRN) to extract different levels of features while alleviating the ringing artifacts. EZSSR neither requires paired training datasets nor preprocessing stage. The computational complexity of EZSSR is low, and thus, EZSSR can also reconstruct image within the acceptable time frame. Experimental results show that EZSSR reconstructs images with better visual effects and higher PSNR and SSIM values."
132
+ ]
133
+
134
+ ]
135
+
136
+ # ===== Gradio 界面 =====
137
  with gr.Blocks() as iface:
138
  gr.Markdown("""
139
  # 📈 Predict Academic Impact of Newly Published Paper!
 
141
  ###### [Full Paper](https://arxiv.org/abs/2408.03934)
142
  ###### Please be advised: Local inference of the proposed method is instant, but ZeroGPU requires quantized model reinitialization with each "Predict", causing slight delays. (typically wont take more than 30 secs)
143
  """)
144
+
145
  with gr.Row():
146
  with gr.Column():
147
+ model_selector = gr.Dropdown(
148
+ choices=["v1", "v2"],
149
+ value="v2", # 默认 v2
150
+ label="Select Model Version"
151
+ )
152
  title_input = gr.Textbox(
153
  lines=2,
154
+ placeholder="Enter Paper Title Here...",
155
  label="Paper Title"
156
  )
157
  abstract_input = gr.Textbox(
158
  lines=5,
159
+ placeholder="Enter Paper Abstract Here...",
160
  label="Paper Abstract"
161
  )
162
  validation_status = gr.Textbox(label="Validation Status", interactive=False)
163
  submit_button = gr.Button("Predict Impact", interactive=False)
164
+
165
  with gr.Column():
166
  output = gr.Label(label="Predicted Impact")
167
  gr.Markdown("""
 
175
  - The **author takes NO responsibility** for the prediction results.
176
  """)
177
 
 
178
  title_input.change(
179
  update_button_status,
180
  inputs=[title_input, abstract_input],
 
188
 
189
  submit_button.click(
190
  predict,
191
+ inputs=[title_input, abstract_input, model_selector],
192
  outputs=output
193
  )
194
 
 
198
  outputs=[validation_status, output],
199
  cache_examples=False
200
  )
201
+
202
  iface.launch()
203
+