# Drug-Gemma / src/model.py
# Author: Jay Prajapati
# Version: v1.0.0
# Revision: b8ffacb
import os
import json
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load variables from a local .env file (if present) into os.environ so that
# HF_TOKEN can be supplied without exporting it in the shell.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hugging Face repo id used for the prompt file, tokenizer and model alike.
MODEL_ID = "google/txgemma-2b-predict"

# Fetch the TDC prompt templates stored in the model repo. The token is passed
# explicitly, matching the tokenizer/model loads below.
tdc_prompts_filepath = hf_hub_download(
    repo_id=MODEL_ID,
    filename="tdc_prompts.json",
    token=HF_TOKEN,
)

# Parse the prompt templates right after downloading them so a bad/missing
# file fails fast, before the comparatively slow model load.
with open(tdc_prompts_filepath, "r", encoding="utf-8") as f:
    tdc_prompts = json.load(f)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
)

# device_map="auto" lets transformers/accelerate choose device placement
# (GPU if available); inference inputs must follow model.device.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    token=HF_TOKEN,
)
def txgemma_predict(prompt):
    """Run TxGemma on ``prompt`` and return the decoded generation.

    Args:
        prompt: A fully formatted prompt string (e.g. a filled-in TDC template).

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed. For a decoder-only model, ``generate`` returns the prompt
        tokens followed by the new tokens, so the decoded text begins with the
        prompt; callers extract the answer portion themselves.
    """
    # The model is loaded with device_map="auto" and may live on a GPU; the
    # inputs must be on the same device or generate() raises a device-mismatch
    # error. The previous hard-coded "cpu" only worked on CPU-only machines.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def predict_kiba_score(drug_smile, amino_acid):
    """Ask TxGemma for the KIBA score of a drug/target pair.

    Args:
        drug_smile: Drug SMILES string, substituted for ``{Drug SMILES}`` in
            the "KIBA" template.
        amino_acid: Target amino-acid sequence, substituted for
            ``{Target amino acid sequence}``.

    Returns:
        The whitespace-stripped text between the first and second
        ``"Answer:"`` markers of the model's decoded output (or everything
        after the first marker if there is only one).

    Raises:
        KeyError: if ``tdc_prompts`` has no "KIBA" template.
        IndexError: if the decoded output contains no ``"Answer:"`` marker.
    """
    template = tdc_prompts["KIBA"]
    filled = template.replace("{Drug SMILES}", drug_smile)
    filled = filled.replace("{Target amino acid sequence}", amino_acid)

    decoded = txgemma_predict(filled)
    # Segment [1]: the text right after the first "Answer:" occurrence.
    answer_segment = decoded.split("Answer:")[1]
    return answer_segment.strip()
def predict(task, drug_smile, amino_acid=None):
if task == "KIBA Score":
if amino_acid is None:
raise ValueError("amino_acid parameter is required for KIBA task")
kiba_score = predict_kiba_score(drug_smile, amino_acid)
return f"{kiba_score} Binding Affinity On Scale of 0-1000"
if task == "Skin Reaction":
TDC_PROMPT = tdc_prompts["Skin_Reaction"].replace("{Drug SMILES}", drug_smile)
response = txgemma_predict(TDC_PROMPT).split("Answer:")[1].strip()
if "(A)" in response: response = f"{drug_smile} does not cause a skin reaction!"
elif "(B)" in response: response = f"{drug_smile} causes a skin reaction!"
return response
if task == "Liver Safety":
TDC_PROMPT = tdc_prompts["DILI"].replace("{Drug SMILES}", drug_smile)
response = txgemma_predict(TDC_PROMPT).split("Answer:")[1].strip()
if "(A)" in response: response = f"{drug_smile} does not damage a liver!"
elif "(B)" in response: response = f"{drug_smile} can damage a liver!"