Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| from dotenv import load_dotenv | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Pull settings such as HF_TOKEN from a local .env file, if one exists.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")

# Single source of truth for the checkpoint used by every download below.
MODEL_ID = "google/txgemma-2b-predict"

# Prompt templates for the TDC tasks, stored in the model repo. The token is
# passed explicitly, as for the tokenizer/model, since a gated repo needs it.
tdc_prompts_filepath = hf_hub_download(
    repo_id=MODEL_ID,
    filename="tdc_prompts.json",
    token=HF_TOKEN,
)
# Mapping of task name (e.g. "KIBA", "DILI") -> prompt template string.
with open(tdc_prompts_filepath, "r", encoding="utf-8") as f:
    tdc_prompts = json.load(f)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# device_map="auto" lets transformers choose where the weights live (CPU or
# GPU); it requires the `accelerate` package to be installed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    token=HF_TOKEN,
)
def txgemma_predict(prompt, max_new_tokens=8):
    """Run TxGemma on *prompt* and return the decoded output text.

    Args:
        prompt: A fully formatted TDC prompt string.
        max_new_tokens: Upper bound on the number of generated tokens. The
            answers parsed by callers are short (a choice label such as "(A)"
            or a numeric score), so a small budget suffices. Pass ``None`` to
            fall back to the model's default generation length.

    Returns:
        The decoded first sequence (prompt followed by the generated tokens),
        with special tokens removed.
    """
    # Send the encoded inputs to wherever the model's weights were placed:
    # with device_map="auto" that may be a GPU, and hard-coding "cpu" would
    # then cause a device-mismatch error inside generate().
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def extract_answer(response):
    """Return the stripped text following the first ``"Answer:"`` in *response*.

    The decoded model output echoes the prompt, and the prediction is expected
    right after the ``"Answer:"`` marker (presumably where the TDC template
    ends). As with ``response.split("Answer:")[1]``, only the segment between
    the first marker and any later one is returned.

    Raises:
        ValueError: if *response* contains no ``"Answer:"`` marker.
    """
    parts = response.split("Answer:")
    if len(parts) < 2:
        raise ValueError(f"model output has no 'Answer:' marker: {response!r}")
    return parts[1].strip()


def predict_kiba_score(drug_smile, amino_acid):
    """Predict the KIBA drug-target binding score for a drug/protein pair.

    Args:
        drug_smile: SMILES string of the drug.
        amino_acid: Amino-acid sequence of the target protein.

    Returns:
        The model's answer text for the KIBA prompt, stripped of whitespace.

    Raises:
        ValueError: if the model output contains no ``"Answer:"`` marker.
    """
    prompt = (
        tdc_prompts["KIBA"]
        .replace("{Drug SMILES}", drug_smile)
        .replace("{Target amino acid sequence}", amino_acid)
    )
    return extract_answer(txgemma_predict(prompt))
| def predict(task, drug_smile, amino_acid=None): | |
| if task == "KIBA Score": | |
| if amino_acid is None: | |
| raise ValueError("amino_acid parameter is required for KIBA task") | |
| kiba_score = predict_kiba_score(drug_smile, amino_acid) | |
| return f"{kiba_score} Binding Affinity On Scale of 0-1000" | |
| if task == "Skin Reaction": | |
| TDC_PROMPT = tdc_prompts["Skin_Reaction"].replace("{Drug SMILES}", drug_smile) | |
| response = txgemma_predict(TDC_PROMPT).split("Answer:")[1].strip() | |
| if "(A)" in response: response = f"{drug_smile} does not cause a skin reaction!" | |
| elif "(B)" in response: response = f"{drug_smile} causes a skin reaction!" | |
| return response | |
| if task == "Liver Safety": | |
| TDC_PROMPT = tdc_prompts["DILI"].replace("{Drug SMILES}", drug_smile) | |
| response = txgemma_predict(TDC_PROMPT).split("Answer:")[1].strip() | |
| if "(A)" in response: response = f"{drug_smile} does not damage a liver!" | |
| elif "(B)" in response: response = f"{drug_smile} can damage a liver!" | |