Update app.py
Browse files
app.py
CHANGED
|
@@ -9,9 +9,6 @@ import time
|
|
| 9 |
|
| 10 |
is_stopped = False
|
| 11 |
|
| 12 |
-
# seed = random.randint(0,100000)
|
| 13 |
-
setup_seed(4)
|
| 14 |
-
|
| 15 |
def temperature_sampling(logits, temperature):
|
| 16 |
logits = logits / temperature
|
| 17 |
probabilities = torch.softmax(logits, dim=-1)
|
|
@@ -23,7 +20,12 @@ def stop_generation():
|
|
| 23 |
is_stopped = True
|
| 24 |
return "Generation stopped."
|
| 25 |
|
| 26 |
-
def CTXGen(X0, X3, X1, X2, τ, g_num, model_name):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
global is_stopped
|
| 28 |
is_stopped = False
|
| 29 |
|
|
@@ -162,7 +164,7 @@ def CTXGen(X0, X3, X1, X2, τ, g_num, model_name):
|
|
| 162 |
'Subtype_probability': cls_probability_all,
|
| 163 |
'Potency': X2,
|
| 164 |
'Potency_probability': act_probability_all,
|
| 165 |
-
'Random_seed': seed
|
| 166 |
})
|
| 167 |
out.to_csv("output.csv", index=False, encoding='utf-8-sig')
|
| 168 |
count += 1
|
|
@@ -198,6 +200,7 @@ with gr.Blocks() as demo:
|
|
| 198 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
| 199 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
| 200 |
model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
|
|
|
|
| 201 |
with gr.Row():
|
| 202 |
start_button = gr.Button("Start Generation")
|
| 203 |
stop_button = gr.Button("Stop Generation")
|
|
@@ -206,7 +209,7 @@ with gr.Blocks() as demo:
|
|
| 206 |
with gr.Row():
|
| 207 |
output_df = gr.DataFrame(label="Generated Conotoxins")
|
| 208 |
|
| 209 |
-
start_button.click(CTXGen, inputs=[X0, X3, X1, X2, τ, g_num, model_name], outputs=[output_file, output_df])
|
| 210 |
stop_button.click(stop_generation, outputs=None)
|
| 211 |
|
| 212 |
demo.launch()
|
|
|
|
| 9 |
|
| 10 |
is_stopped = False
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
def temperature_sampling(logits, temperature):
|
| 13 |
logits = logits / temperature
|
| 14 |
probabilities = torch.softmax(logits, dim=-1)
|
|
|
|
| 20 |
is_stopped = True
|
| 21 |
return "Generation stopped."
|
| 22 |
|
| 23 |
+
def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
|
| 24 |
+
if seed =='random'
|
| 25 |
+
seed = random.randint(0,100000)
|
| 26 |
+
setup_seed(seed)
|
| 27 |
+
else:
|
| 28 |
+
setup_seed(int(seed))
|
| 29 |
global is_stopped
|
| 30 |
is_stopped = False
|
| 31 |
|
|
|
|
| 164 |
'Subtype_probability': cls_probability_all,
|
| 165 |
'Potency': X2,
|
| 166 |
'Potency_probability': act_probability_all,
|
| 167 |
+
'Random_seed': int(seed)
|
| 168 |
})
|
| 169 |
out.to_csv("output.csv", index=False, encoding='utf-8-sig')
|
| 170 |
count += 1
|
|
|
|
| 200 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
| 201 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
| 202 |
model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
|
| 203 |
+
seed = gr.Textbox(label="Seed", value="random")
|
| 204 |
with gr.Row():
|
| 205 |
start_button = gr.Button("Start Generation")
|
| 206 |
stop_button = gr.Button("Stop Generation")
|
|
|
|
| 209 |
with gr.Row():
|
| 210 |
output_df = gr.DataFrame(label="Generated Conotoxins")
|
| 211 |
|
| 212 |
+
start_button.click(CTXGen, inputs=[X0, X3, X1, X2, τ, g_num, model_name, seed], outputs=[output_file, output_df])
|
| 213 |
stop_button.click(stop_generation, outputs=None)
|
| 214 |
|
| 215 |
demo.launch()
|