PhDFlo commited on
Commit
ae4465f
·
1 Parent(s): 0a38dda

Include json config

Browse files
Files changed (1) hide show
  1. app.py +46 -27
app.py CHANGED
@@ -46,28 +46,29 @@ def create_fasta_file(sequence: str, name: Optional[str] = None) -> str:
46
 
47
  return file_name
48
 
49
- # Function to create a custom JSON config file
50
- def create_custom_config(
51
- num_trunk_recycles: int = 3,
52
- num_diffn_timesteps: int = 200,
53
- seed: int = 42,
54
- use_esm_embeddings: bool = True,
55
- use_msa_server: bool = True,
56
- output_file: Optional[str] = None
57
- ) -> str:
58
- """Create a custom JSON configuration file for Chai1 inference.
59
 
60
  Args:
61
- num_trunk_recycles (int, optional): Number of trunk recycles. Defaults to 3.
62
- num_diffn_timesteps (int, optional): Number of diffusion timesteps. Defaults to 200.
63
- seed (int, optional): Random seed. Defaults to 42.
64
- use_esm_embeddings (bool, optional): Whether to use ESM embeddings. Defaults to True.
65
- use_msa_server (bool, optional): Whether to use MSA server. Defaults to True.
66
- output_file (str, optional): Path to save the config file. If None, saves to default location.
67
 
68
  Returns:
69
- str: Path to the created config file
70
  """
 
 
 
 
 
71
  config = {
72
  "num_trunk_recycles": num_trunk_recycles,
73
  "num_diffn_timesteps": num_diffn_timesteps,
@@ -76,13 +77,17 @@ def create_custom_config(
76
  "use_msa_server": use_msa_server
77
  }
78
 
79
- if output_file is None:
80
- output_file = here / "inputs" / "chai1_custom_inference.json"
81
-
82
- with open(output_file, "w") as f:
 
 
 
83
  json.dump(config, f, indent=4)
84
-
85
- return str(output_file)
 
86
 
87
  # Function to compute Chai1 inference
88
  def compute_Chai1(
@@ -115,6 +120,7 @@ def compute_Chai1(
115
  # Define inference config file
116
  if not inference_config_file:
117
  inference_config_file = here / "inputs" / "chai1_quick_inference.json"
 
118
  print(f"🧬 loading Chai inference config from {inference_config_file}")
119
  inference_config = json.loads(Path(inference_config_file).read_text())
120
 
@@ -154,11 +160,24 @@ with gr.Blocks() as demo:
154
  """)
155
 
156
  with gr.Tab("Configuration 📦"):
157
- text_input = gr.Textbox(placeholder="Fasta format sequences", label="Fasta content", lines=10)
158
- text_output = gr.Textbox(placeholder="Fasta file name", label="Fasta file name")
159
- text_button = gr.Button("Create Fasta file")
160
- text_button.click(fn=create_fasta_file, inputs=[text_input], outputs=[text_output])
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  gr.Markdown(
163
  """
164
  You can input a Fasta file containing the sequence of the molecule you want to simulate.
 
46
 
47
  return file_name
48
 
49
+ # Function to create a JSON file
50
+ def create_json_config(
51
+ num_diffn_timesteps: int,
52
+ num_trunk_recycles: int,
53
+ seed: int,
54
+ options: list
55
+ ) -> str:
56
+ """Create a JSON configuration file from the Gradio interface inputs.
 
 
57
 
58
  Args:
59
+ num_diffn_timesteps (int): Number of diffusion timesteps from slider
60
+ num_trunk_recycles (int): Number of trunk recycles from slider
61
+ seed (int): Random seed from slider
62
+ options (list): List of selected options from checkbox group
 
 
63
 
64
  Returns:
65
+ str: Name of the created JSON file
66
  """
67
+ # Convert checkbox options to boolean flags
68
+ use_esm_embeddings = "ESM_embeddings" in options
69
+ use_msa_server = "MSA_server" in options
70
+
71
+ # Create config dictionary
72
  config = {
73
  "num_trunk_recycles": num_trunk_recycles,
74
  "num_diffn_timesteps": num_diffn_timesteps,
 
77
  "use_msa_server": use_msa_server
78
  }
79
 
80
+ # Generate a unique file name
81
+ unique_id = hashlib.sha256(uuid4().bytes).hexdigest()[:8]
82
+ file_name = f"chai1_{unique_id}_config.json"
83
+ file_path = here / "inputs" / file_name
84
+
85
+ # Write the JSON file
86
+ with open(file_path, "w") as f:
87
  json.dump(config, f, indent=4)
88
+
89
+ return file_name
90
+
91
 
92
  # Function to compute Chai1 inference
93
  def compute_Chai1(
 
120
  # Define inference config file
121
  if not inference_config_file:
122
  inference_config_file = here / "inputs" / "chai1_quick_inference.json"
123
+ inference_config_file = here / "inputs" / inference_config_file
124
  print(f"🧬 loading Chai inference config from {inference_config_file}")
125
  inference_config = json.loads(Path(inference_config_file).read_text())
126
 
 
160
  """)
161
 
162
  with gr.Tab("Configuration 📦"):
 
 
 
 
163
 
164
+ with gr.Row():
165
+ with gr.Column(scale=1):
166
+ slider_nb = gr.Slider(1, 500, value=200, label="Number of diffusion time steps", info="Choose the number of diffusion time steps for the simulation", step=1, interactive=True, elem_id="num_iterations")
167
+ slider_trunk = gr.Slider(1, 5, value=3, label="Number of trunk recycles", info="Choose the number of iterations for the simulation", step=1, interactive=True, elem_id="trunk_number")
168
+ slider_seed = gr.Slider(1, 100, value=42, label="Seed", info="Choose the seed", step=1, interactive=True, elem_id="seed")
169
+ check_options = gr.CheckboxGroup(["ESM_embeddings", "MSA_server"], value=["ESM_embeddings",], label="Additionnal options", info="Options to use ESM embeddings and MSA server", elem_id="options")
170
+ json_output = gr.Textbox(placeholder="Config file name", label="Config file name")
171
+ button_json = gr.Button("Create Config file")
172
+ button_json.click(fn=create_json_config, inputs=[slider_nb, slider_trunk, slider_seed, check_options], outputs=[json_output])
173
+
174
+ with gr.Column(scale=1):
175
+ text_input = gr.Textbox(placeholder="Fasta format sequences", label="Fasta content", lines=10)
176
+ text_output = gr.Textbox(placeholder="Fasta file name", label="Fasta file name")
177
+ text_button = gr.Button("Create Fasta file")
178
+ text_button.click(fn=create_fasta_file, inputs=[text_input], outputs=[text_output])
179
+
180
+
181
  gr.Markdown(
182
  """
183
  You can input a Fasta file containing the sequence of the molecule you want to simulate.