File size: 2,163 Bytes
b8ffacb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import json
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load variables from a local .env file (if any) into the process environment
# before reading the Hugging Face access token.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is not set

# Single source of truth for the Hub repository used by every download below.
MODEL_ID = "google/txgemma-2b-predict"

# Download the prompt-template file that ships in the model repository.
# The token is passed explicitly, exactly like the tokenizer/model loads below,
# so that access does not depend on an ambient `huggingface-cli login`.
tdc_prompts_filepath = hf_hub_download(
    repo_id=MODEL_ID,
    filename="tdc_prompts.json",
    token=HF_TOKEN,
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
)

# device_map="auto" lets the library decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    token=HF_TOKEN,
)

# Mapping of task name -> prompt template, indexed later by keys such as
# "KIBA", "Skin_Reaction" and "DILI". Read as UTF-8 explicitly so the result
# does not depend on the platform's default locale encoding.
with open(tdc_prompts_filepath, "r", encoding="utf-8") as f:
    tdc_prompts = json.load(f)

def txgemma_predict(prompt):
    """Run *prompt* through the module-level model and return the decoded text.

    Args:
        prompt: The fully formatted prompt string to send to the model.

    Returns:
        The first generated sequence decoded with special tokens removed.
    """
    # The model is loaded with device_map="auto", so it may live on an
    # accelerator; the inputs must be on the same device as the model rather
    # than hard-coded to the CPU, otherwise generate() fails with a device
    # mismatch.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def predict_kiba_score(drug_smile, amino_acid):
    """Query the model with the "KIBA" prompt template and return its answer.

    The template's ``{Drug SMILES}`` and ``{Target amino acid sequence}``
    placeholders are substituted (in that order) with the given arguments,
    the prompt is passed to ``txgemma_predict``, and the text that follows
    the first ``"Answer:"`` marker in the result is returned.

    Args:
        drug_smile: String substituted for the ``{Drug SMILES}`` placeholder.
        amino_acid: String substituted for the
            ``{Target amino acid sequence}`` placeholder.

    Returns:
        The whitespace-stripped segment after the first ``"Answer:"`` marker.

    Raises:
        IndexError: If the model output contains no ``"Answer:"`` marker.
    """
    substitutions = (
        ("{Drug SMILES}", drug_smile),
        ("{Target amino acid sequence}", amino_acid),
    )

    prompt = tdc_prompts["KIBA"]
    for placeholder, value in substitutions:
        prompt = prompt.replace(placeholder, value)

    generated = txgemma_predict(prompt)
    segments = generated.split("Answer:")
    return segments[1].strip()

def _binary_verdict(prompt_key, drug_smile, message_a, message_b):
    """Ask a drug-only (A)/(B) question and translate the model's choice.

    Fills the ``{Drug SMILES}`` placeholder of ``tdc_prompts[prompt_key]``,
    runs the prompt through ``txgemma_predict`` and inspects the text after
    the first ``"Answer:"`` marker.

    Returns ``message_a`` when the answer contains ``"(A)"``, ``message_b``
    when it contains ``"(B)"`` (``"(A)"`` is checked first), and the raw
    stripped answer text when it contains neither.
    """
    prompt = tdc_prompts[prompt_key].replace("{Drug SMILES}", drug_smile)
    answer = txgemma_predict(prompt).split("Answer:")[1].strip()

    if "(A)" in answer:
        return message_a
    if "(B)" in answer:
        return message_b
    return answer


def predict(task, drug_smile, amino_acid=None):
    """Dispatch a prediction request to the prompt that matches *task*.

    Args:
        task: One of ``"KIBA Score"``, ``"Skin Reaction"`` or
            ``"Liver Safety"``.
        drug_smile: Drug string substituted into the prompt template.
        amino_acid: Target sequence; required only for ``"KIBA Score"``.

    Returns:
        A human-readable result string for a recognised task, or ``None``
        when *task* is not one of the supported names.

    Raises:
        ValueError: If *task* is ``"KIBA Score"`` and *amino_acid* is None.
    """
    if task == "KIBA Score":
        if amino_acid is None:
            raise ValueError("amino_acid parameter is required for KIBA task")
        kiba_score = predict_kiba_score(drug_smile, amino_acid)
        return f"{kiba_score} Binding Affinity On Scale of 0-1000"

    if task == "Skin Reaction":
        return _binary_verdict(
            "Skin_Reaction",
            drug_smile,
            f"{drug_smile} does not cause a skin reaction!",
            f"{drug_smile} causes a skin reaction!",
        )

    if task == "Liver Safety":
        # This branch previously computed its message but never returned it,
        # so callers always received None for "Liver Safety".
        return _binary_verdict(
            "DILI",
            drug_smile,
            f"{drug_smile} does not damage a liver!",
            f"{drug_smile} can damage a liver!",
        )

    # Unrecognised task: keep the historical behaviour of returning None.
    return None