| from metano_inference import load_model_from_hf, predict_neurosymbolic, SymbolicScorer, ModelConfig |
|
|
# Defaults for the smoke test; override them through main()'s keyword arguments.
DEFAULT_MODEL_ID = "banulaperera/metano"
DEFAULT_TEST_INCHI = (
    "InChI=1/Fe.Na.H2O4S.H2O.H/c;;1-5(2,3)4;;"
    "/h;;(H2,1,2,3,4);1H2;/q;+1;;;-1"
)


def _fmt_score(value) -> str:
    """Format a score with three decimals, or return ``"n/a"`` when it is None.

    NOTE(review): the None branch is defensive. It is not confirmed from here
    that ``predict_neurosymbolic`` ever returns None scores (e.g. on a hard
    fail), but plain ``:.3f`` formatting would raise TypeError if it did.
    """
    return "n/a" if value is None else f"{value:.3f}"


def _format_report(out: dict) -> str:
    """Render the result dict of ``predict_neurosymbolic`` as a printable report.

    Reads ``predicted_iupac``, ``hard_fail``, ``combined_score``,
    ``symbolic_score`` and ``neural_score`` (required), plus ``reasons`` and
    ``candidates`` (optional; each candidate is read as a dict with
    ``combined`` and ``text`` keys).

    Returns:
        The report text, one line per entry and without a trailing newline.
    """
    lines = [
        "=== TEST RESULTS ===",
        f"Predicted IUPAC: {out['predicted_iupac']}",
        f"Hard Fail Triggered: {out['hard_fail']}",
        f"Combined Score: {_fmt_score(out['combined_score'])}",
        f"Symbolic Score: {_fmt_score(out['symbolic_score'])}",
        f"Neural Score: {_fmt_score(out['neural_score'])}",
    ]

    reasons = out.get("reasons")
    if reasons:
        lines.append(f"Penalty Reasons: {reasons}")

    lines.append("\nTop Candidates:")
    # List every candidate. The former ``[1:]`` slice silently skipped the
    # first one even though the header promises the top candidates.
    lines.extend(
        f" [{_fmt_score(cand['combined'])}] {cand['text']}"
        for cand in out.get("candidates", [])
    )
    return "\n".join(lines)


def main(inchi: str = DEFAULT_TEST_INCHI, model_id: str = DEFAULT_MODEL_ID) -> None:
    """Run an end-to-end smoke test of the METANO neurosymbolic pipeline.

    Loads the model, builds a ``SymbolicScorer`` over the metal elements from
    the default ``ModelConfig``, predicts an IUPAC name for *inchi*, and
    prints a results report.

    Args:
        inchi: InChI string to name. Defaults to ``DEFAULT_TEST_INCHI``.
        model_id: Identifier handed to ``load_model_from_hf``. Presumably a
            Hugging Face Hub repo id; confirm against that loader.
    """
    print("Loading local METANO model...")
    # NOTE(review): despite the message, loading goes through
    # load_model_from_hf; whether it reads a local cache or downloads is
    # decided inside that function.
    model = load_model_from_hf(model_id)

    print("Initializing Symbolic Scorer...")
    config = ModelConfig()
    scorer = SymbolicScorer(metals=config.metal_elements)

    print(f"\nRunning prediction for InChI:\n{inchi}\n")
    out = predict_neurosymbolic(
        model=model,
        inchi=inchi,
        scorer=scorer,
        num_candidates=5,
        repair_num_candidates=5,
        max_repair_rounds=1,
    )

    # A single print of the joined report emits exactly what per-line prints would.
    print(_format_report(out))
|
|
if __name__ == "__main__":
    # Script entry point: run the smoke test only when this file is executed
    # directly, never as a side effect of importing it.
    main()