# molecular_Designer_Env / server /molecular_Designer_Env_environment.py
# Mrkumar007's picture
# Upload folder using huggingface_hub
# 2106752 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Molecular Designer Environment Implementation.
A highly realistic molecular design environment using RDKit for chemical property feedback.
"""
import os
import math
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors
from rdkit.Chem import QED
except ImportError:
# Handle gracefully if RDKit is not installed locally but might be in container
Chem = None
Descriptors = None
rdMolDescriptors = None
QED = None
try:
from ..models import MolecularDesignerEnvAction, MolecularDesignerEnvObservation
except ImportError:
from models import MolecularDesignerEnvAction, MolecularDesignerEnvObservation
class MolecularDesignerEnvEnvironment(Environment):
    """
    Molecular Designer environment powered by RDKit.

    Each step receives a SMILES string, computes molecular descriptors for it
    and scores it against one of three tasks ("easy", "medium", "hard").
    The task is read once, at construction time, from the ``TASK_NAME``
    environment variable and defaults to "easy".

    An episode ends when the task's success threshold is reached or after
    ``MAX_STEPS`` steps, whichever comes first.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    MAX_STEPS = 10

    def __init__(self):
        """Create a fresh episode state and read the task from the environment."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count = 0
        self.task_name = os.getenv("TASK_NAME", "easy")
        # SMILES strings already submitted in the current episode.
        self.history = set()

    def reset(self) -> MolecularDesignerEnvObservation:
        """Start a new episode.

        Replaces the state with a new episode id and a zero step count,
        clears the submission history and increments the reset counter.

        Returns:
            An initial observation with zeroed descriptors, ``done=False``
            and ``reward=0.0``.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count += 1
        self.history = set()
        return MolecularDesignerEnvObservation(
            is_valid=False,
            mw=0.0,
            logp=0.0,
            qed=0.0,
            sas=0.0,
            feedback=f"Molecular Designer initialized. Task: {self.task_name}. Submit a SMILES string.",
            done=False,
            reward=0.0,
        )

    def _gaussian_reward(self, val: float, target: float, sigma: float) -> float:
        """Return a Gaussian-shaped score in (0, 1] that peaks at ``target``.

        Args:
            val: Observed value.
            target: Value at which the score equals 1.0.
            sigma: Width of the bell curve; must be non-zero.
        """
        return math.exp(-((val - target) ** 2) / (2 * sigma ** 2))

    def evaluate_molecule(self, smiles: str) -> tuple:
        """Parse a SMILES string and compute its descriptors.

        Args:
            smiles: Candidate SMILES string.

        Returns:
            A tuple ``(is_valid, mw, logp, qed, sas, message)``. When the
            molecule cannot be evaluated, ``is_valid`` is False, all numeric
            fields are 0.0 and ``message`` explains why.
        """
        if Chem is None:
            return False, 0.0, 0.0, 0.0, 0.0, "RDKit not installed."

        # RDKit parses an empty string into a zero-atom molecule instead of
        # returning None, so empty input has to be rejected explicitly or it
        # would be scored as a valid molecule.
        if not smiles or not smiles.strip():
            return False, 0.0, 0.0, 0.0, 0.0, "Empty SMILES string."

        mol = Chem.MolFromSmiles(smiles)
        if mol is None or mol.GetNumAtoms() == 0:
            return False, 0.0, 0.0, 0.0, 0.0, "Invalid SMILES string."

        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        qed = QED.qed(mol)

        # Simplified proxy for Synthetic Accessibility Score:
        # ring count plus molecular weight in hundreds.
        n_rings = Descriptors.RingCount(mol)
        sas = float(n_rings) + (mw / 100.0)
        return True, mw, logp, qed, sas, "Valid molecule."

    def calculate_reward(self, is_valid: bool, smiles: str, mw: float, logp: float, qed: float, sas: float):
        """Compute the task-specific reward for an evaluated molecule.

        Args:
            is_valid: Whether the molecule was successfully evaluated.
            smiles: The submitted SMILES string (not used in the scoring).
            mw: Molecular weight.
            logp: LogP value.
            qed: QED score.
            sas: Synthetic accessibility proxy (not used in the scoring).

        Returns:
            -1.0 for an invalid molecule; otherwise a continuously scaled
            score for the configured task. An unrecognised task name
            yields 0.0.
        """
        if not is_valid:
            return -1.0

        if self.task_name == "easy":
            # Task 1: Generate a molecule within specific MW range (300-350)
            target_mw = 325.0
            return self._gaussian_reward(mw, target_mw, 50.0)

        if self.task_name == "medium":
            # Task 2: Rule of 5 adherence (MW < 500, LogP < 5) + moderate QED.
            # Each term is 1.0 inside the limit and decays linearly to 0.0
            # beyond it.
            reward_mw = 1.0 if mw <= 500 else max(0.0, 1.0 - (mw - 500) / 100.0)
            reward_logp = 1.0 if logp <= 5.0 else max(0.0, 1.0 - (logp - 5.0))
            return (reward_mw * 0.4) + (reward_logp * 0.4) + (qed * 0.2)

        if self.task_name == "hard":
            # Task 3: High QED (> 0.8), target LogP (2.0 to 3.0). Hardest multi-objective.
            reward_logp = self._gaussian_reward(logp, 2.5, 1.0)
            return (qed * 0.7) + (reward_logp * 0.3)

        return 0.0

    def _invalid_observation(self, feedback: str, reward: float) -> MolecularDesignerEnvObservation:
        """Build a zeroed observation for a rejected submission.

        The episode is marked done only when the step budget is exhausted.
        """
        return MolecularDesignerEnvObservation(
            is_valid=False, mw=0.0, logp=0.0, qed=0.0, sas=0.0,
            feedback=feedback,
            done=self._state.step_count >= self.MAX_STEPS,
            reward=reward,
        )

    def step(self, action: MolecularDesignerEnvAction) -> MolecularDesignerEnvObservation:  # type: ignore[override]
        """Evaluate one submitted SMILES string.

        Args:
            action: Action whose ``smiles`` attribute holds the candidate.

        Returns:
            An observation carrying the descriptors, feedback text, reward
            and done flag. A SMILES already submitted in this episode is
            rejected with reward -0.5; an invalid one with reward -1.0.
        """
        self._state.step_count += 1
        # A missing SMILES is treated like an empty one and scored as invalid
        # instead of raising AttributeError on ``None.strip()``.
        smiles = (action.smiles or "").strip()

        if smiles in self.history:
            return self._invalid_observation(
                f"Duplicate SMILES: {smiles} already tried. Try a new structure.",
                -0.5,  # Penalize repeating
            )
        self.history.add(smiles)

        is_valid, mw, logp, qed, sas, msg = self.evaluate_molecule(smiles)
        reward = self.calculate_reward(is_valid, smiles, mw, logp, qed, sas)
        if not is_valid:
            return self._invalid_observation(msg, reward)

        # Determine success threshold appropriately based on task difficulty
        success_threshold = 0.95 if self.task_name == "easy" else 0.85
        done = reward >= success_threshold or self._state.step_count >= self.MAX_STEPS

        detailed_feedback = (f"Valid: {smiles} | MW: {mw:.2f} | LogP: {logp:.2f} | "
                             f"QED: {qed:.2f} | SAS Proxy: {sas:.2f} | Reward: {reward:.3f}")
        return MolecularDesignerEnvObservation(
            is_valid=is_valid,
            mw=mw,
            logp=logp,
            qed=qed,
            sas=sas,
            feedback=detailed_feedback,
            done=done,
            reward=reward,
        )

    @property
    def state(self) -> State:
        """Current episode state (episode id and step count)."""
        return self._state