gbyuvd commited on
Commit
e08caa0
·
verified ·
1 Parent(s): 4e12c74

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +21 -9
README.md CHANGED
@@ -75,8 +75,8 @@ Please clone the repo first, then you can:
75
  ```python
76
  # ==============================
77
  # Generate SELFIES from ChemQ3MTP checkpoint
78
- # Uses exact same local loading as original training script
79
- # ==============================
80
 
81
  import sys
82
  import os
@@ -89,20 +89,20 @@ chemq3mtp_path = os.path.join(notebook_dir, "ChemQ3MTP")
89
  if chemq3mtp_path not in sys.path:
90
  sys.path.insert(0, chemq3mtp_path)
91
 
92
- # Optional: clean up duplicate paths (as in your training script)
93
  existing_paths = [p for p in sys.path if p.endswith("ChemQ3MTP")]
94
  for path in existing_paths[:-1]: # keep only the most recently added
95
  sys.path.remove(path)
96
 
97
  # Now import from local ChemQ3MTP folder
98
  from FastChemTokenizerHF import FastChemTokenizerSelfies
99
- from ChemQ3MTP import ChemQ3MTPForCausalLM # <-- your custom model
100
 
101
  # --- Load from checkpoint (same as saved in training) ---
102
- checkpoint_dir = "./enhanced-qwen3-final" # or your actual checkpoint path
103
 
104
- print(f"Loading tokenizer from {checkpoint_dir}...")
105
- tokenizer = FastChemTokenizerSelfies.from_pretrained(checkpoint_dir)
106
 
107
  print(f"Loading ChemQ3MTP model from {checkpoint_dir}...")
108
  model = ChemQ3MTPForCausalLM.from_pretrained(checkpoint_dir)
@@ -140,17 +140,24 @@ except Exception as e:
140
  print(f"Generation failed: {e}")
141
  import traceback
142
  traceback.print_exc()
 
 
 
 
 
 
143
  ```
144
 
145
- **and for visualisation:**
146
 
147
  ```python
148
  # Generate Mol Viz
149
  from rdkit import Chem
150
  from rdkit.Chem import Draw
 
151
 
152
  input_ids = tokenizer("<s>", return_tensors="pt").input_ids.to(device)
153
- gen = model.generate(input_ids, max_length=25, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)
154
  generatedmol = tokenizer.decode(gen[0], skip_special_tokens=True)
155
 
156
  test = generatedmol.replace(' ', '')
@@ -160,8 +167,13 @@ mol = Chem.MolFromSmiles(csmi_gen)
160
 
161
  # Draw the molecule
162
  Draw.MolToImage(mol)
 
 
163
  ```
164
 
 
 
 
165
  ---
166
 
167
  ## 📊 Model Architecture
 
75
  ```python
76
  # ==============================
77
  # Generate SELFIES from ChemQ3MTP checkpoint
78
+ # LOADING THE MODEL & TOKENIZER
79
+ # ================================
80
 
81
  import sys
82
  import os
 
89
  if chemq3mtp_path not in sys.path:
90
  sys.path.insert(0, chemq3mtp_path)
91
 
92
+ # Optional: clean up duplicate paths
93
  existing_paths = [p for p in sys.path if p.endswith("ChemQ3MTP")]
94
  for path in existing_paths[:-1]: # keep only the most recently added
95
  sys.path.remove(path)
96
 
97
  # Now import from local ChemQ3MTP folder
98
  from FastChemTokenizerHF import FastChemTokenizerSelfies
99
+ from ChemQ3MTP import ChemQ3MTPForCausalLM
100
 
101
  # --- Load from checkpoint (same as saved in training) ---
102
+ checkpoint_dir = "./" # or your actual checkpoint path
103
 
104
+ print(f"Loading tokenizer...")
105
+ tokenizer = FastChemTokenizerSelfies.from_pretrained('./selftok_core/')
106
 
107
  print(f"Loading ChemQ3MTP model from {checkpoint_dir}...")
108
  model = ChemQ3MTPForCausalLM.from_pretrained(checkpoint_dir)
 
140
  print(f"Generation failed: {e}")
141
  import traceback
142
  traceback.print_exc()
143
+
144
+ # Loading tokenizer...
145
+ # ✅ Special tokens bound: 0 1 2 3 4
146
+ # Loading ChemQ3MTP model from ./...
147
+ # Generated SELFIES:
148
+ # .[N] [C] [C] [N] [C] [C] [=C] [C] [=C] [Branch1] ...
149
  ```
150
 
151
+ **Generate and Visualize:**
152
 
153
  ```python
154
  # Generate Mol Viz
155
  from rdkit import Chem
156
  from rdkit.Chem import Draw
157
+ import selfies as sf
158
 
159
  input_ids = tokenizer("<s>", return_tensors="pt").input_ids.to(device)
160
+ gen = model.generate(input_ids, max_length=512, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)
161
  generatedmol = tokenizer.decode(gen[0], skip_special_tokens=True)
162
 
163
  test = generatedmol.replace(' ', '')
 
167
 
168
  # Draw the molecule
169
  Draw.MolToImage(mol)
170
+
171
+ # NC1=NC2=C(Br)C=CC=C2N1CCCCNCCC3=CC=CC(Cl)=C3
172
  ```
173
 
174
+ ![image](https://cdn-uploads.huggingface.co/production/uploads/667da868d653c0b02d6a2399/Ro950Z7AVBGEXqfY5sV94.png)
175
+
176
+
177
  ---
178
 
179
  ## 📊 Model Architecture