# vhbert / app.py
# Hugging Face Space by rajaatif786 — commit 5fe14a7 ("Update app.py")
import gradio as gr
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import os
from download_from_drive import download_file
from uuid import uuid4
# === Download Required Files from Google Drive ===
# download_file("1LauVdBy41kZvldZqK3dQIDApKdLqXiXs", "DNAEncoder.py")
# download_file("12kRAe9nmU-8k20Q32VZYFXkLTiUXFEZT", "Preprocessor.py")
# download_file("18erKly_wBTw_wu0y4eRfTEKtOL0Hs92k", "PretrainedBERT.py")
# download_file("1_bvtrRupabYwHSoXPOwChL-jsJJEj-vV", "inference.py")
# download_file("1BSmhgZr394cNMyvoij1zCNDcwe0QwAsn", "model.pt")
# === Import your modules ===
from DNAEncoder import ConvertDNALabelEncoder
from Preprocessor import PreprocessLLMData
from PretrainedBERT import initialize_pretrained_bert
from inference import inference
# === Model Loading ===
# Build a 2-class (pathogenic vs. non-pathogenic) BERT classifier; the
# returned optimizer is unused here (inference-only app).
bert_classifier, optimizer = initialize_pretrained_bert(2)
# Load fine-tuned weights onto CPU (Spaces free tier has no GPU).
bert_classifier.load_state_dict(torch.load("model.pt", map_location=torch.device("cpu")))
# Inference mode: disables dropout / batch-norm updates.
bert_classifier.eval()
# === DNA Encoder ===
# Converts raw DNA strings into label-encoded model input; shared by all requests.
convertDNALabelEncoder = ConvertDNALabelEncoder()
# === Prediction Function ===
def predict_dna(seq: str):
    """Classify a DNA sequence chunk-by-chunk as human-pathogenic or not.

    The sequence is split into 250 bp chunks; each chunk is written to a
    temporary CSV, encoded, and run through the BERT classifier. Per-chunk
    scores are aggregated into an average confidence and a majority label.

    Parameters
    ----------
    seq : str
        Raw DNA sequence (whitespace/newlines from pasted FASTA are removed).

    Returns
    -------
    dict
        Flat JSON-serializable result with chunk scores, aggregate label and
        placeholder GFF annotations, or ``{"error": ...}`` on failure.
    """
    try:
        # Remove ALL whitespace (newlines, spaces, tabs) that textbox pastes
        # of FASTA-style sequences typically contain.
        seq = "".join(seq.split())
        genome_length = len(seq)
        if genome_length == 0:
            # Guard: np.mean([]) is nan and max(set([])) raises — fail clearly.
            return {"error": "Empty DNA sequence."}
        chunk_size = 250
        chunks = []
        # Slice genome into 250bp chunks
        for start in range(0, genome_length, chunk_size):
            chunk_seq = seq[start:start + chunk_size]
            # Unique temp CSV per chunk: the Gradio queue can serve concurrent
            # requests, so a fixed "temp_input.csv" would race between users.
            temp_csv = f"temp_input_{uuid4().hex}.csv"
            pd.DataFrame({'seq': [chunk_seq], 'label': [0]}).to_csv(temp_csv, index=False)
            try:
                # Encode sequence (reads the CSV path; dummy label column).
                df_encoded, y = convertDNALabelEncoder.convert_dna_string_to_dna_labelencoder(
                    temp_csv, 'seq', 'label'
                )
            finally:
                # Always clean up, even if encoding raises.
                if os.path.exists(temp_csv):
                    os.remove(temp_csv)
            # Preprocess into a dataloader for the classifier.
            preprocessor = PreprocessLLMData(df_encoded[0], y[0])
            _, _, _, test_dataloader = preprocessor.preprocess()
            # Inference: probs is the pathogenic-class probability for the chunk.
            probs, _, _ = inference(bert_classifier, test_dataloader, device='cpu')
            predicted_class = int(probs >= 0.5)
            chunks.append({
                "start": start,
                "end": min(start + chunk_size, genome_length),
                "score": float(probs),
                "label": "pathogenic" if predicted_class == 1 else "nonpathogenic"
            })
        # Global summary: mean score and majority vote across chunks.
        avg_score = float(np.mean([c["score"] for c in chunks]))
        labels = [c["label"] for c in chunks]  # built once; reused by max key
        majority_class = max(set(labels), key=labels.count)
        # Example GFF (replace with real annotations if available)
        annotations_gff = (
            "##gff-version 3\n"
            "virus . gene 1 500 . + . ID=gene1;Name=ORF1\n"
            "virus . gene 800 1500 . - . ID=gene2;Name=ORF2"
        )
        # Flat JSON (like old version, but extended)
        return {
            "input_sequence": seq[:30] + "...",
            "confidence": avg_score,
            "label_name": "Human-Pathogenic" if majority_class == "pathogenic" else "Non-Human",
            "chunks": chunks,
            "genome_length": genome_length,
            "annotations_gff": annotations_gff,
            "actual": majority_class,
            "predicted": majority_class,
        }
    except Exception as e:
        # API boundary: surface the failure as JSON instead of a 500.
        return {"error": str(e)}
# === Gradio Interface ===
# JSON-out interface so the Space doubles as a REST endpoint; the named
# api_name lets clients call POST /predict_dna via the Gradio client.
api = gr.Interface(
    fn=predict_dna,
    inputs=gr.Textbox(label="DNA Sequence"),
    outputs="json",
    api_name="/predict_dna"
    # allow_flagging="never"
)
# api_open=True keeps the queue's API accessible to external clients.
api.queue(api_open=True)
# with gr.Blocks() as demo:
# def fn(a: int, b: int, c: list[str]) -> tuple[int, str]:
# return a + b, c[a:b]
# gr.api(fn, api_name="add_and_slice")
# Entry point: launch with a public share link, debug logging, the API docs
# page enabled, and server-side rendering disabled (ssr_mode=False works
# around Spaces SSR issues).
if __name__ == "__main__":
    api.launch( share=True,
               debug=True,
               show_api=True,
               ssr_mode=False) #share=True, debug=True, show_api=True) show_api=True,share=True,ssr_mode=False
# def add_and_slice(a: int, b: int, c: list[int]) -> tuple[int, list[int]]:
# return a + b, c[a:b]
# # === Gradio App ===
# with gr.Blocks() as demo:
# # Visual UI for predict_dna
# gr.Interface(
# fn=predict_dna,
# inputs=gr.Textbox(label="DNA Sequence"),
# outputs="json",
# allow_flagging="never"
# )
# # REST API for add_and_slice
# demo.api(fn=add_and_slice,
# inputs=[gr.Number(), gr.Number(), gr.Textbox()],
# outputs=["number", "json"],
# api_name="/add_and_slice")
# if __name__ == "__main__":
# demo.launch(share=True, show_api=True)
# import gradio as gr
# import torch
# import pandas as pd
# import numpy as np
# from download_from_drive import download_file
# # === Download Required Files from Google Drive ===
# download_file("1LauVdBy41kZvldZqK3dQIDApKdLqXiXs", "DNAEncoder.py")
# download_file("1C_c5Zf074PEh0YD3srurZt8FKz-tAgUB", "Preprocessor.py")
# download_file("18erKly_wBTw_wu0y4eRfTEKtOL0Hs92k", "PretrainedBERT.py")
# download_file("1q4NBD4dfx2xoZQgyOyLfMsOgv2ByUv-y", "inference.py")
# download_file("1BSmhgZr394cNMyvoij1zCNDcwe0QwAsn", "model.pt")
# # === Import your modules ===
# from DNAEncoder import ConvertDNALabelEncoder
# from Preprocessor import PreprocessLLMData
# from PretrainedBERT import initialize_pretrained_bert
# from inference import inference
# # === Model Loading ===
# bert_classifier, optimizer = initialize_pretrained_bert(2)
# bert_classifier.load_state_dict(torch.load("model.pt", map_location=torch.device("cpu")), strict=False)
# bert_classifier.eval()
# # === DNA Encoder ===
# convertDNALabelEncoder = ConvertDNALabelEncoder()
# # === Prediction Function ===
# def predict_dna(seq):
# try:
# # Create DataFrame for input sequence
# df = pd.DataFrame({'seq': [seq], 'label': [0]}) # dummy label
# # Apply Label Encoding
# df_encoded, y = convertDNALabelEncoder.convert_dna_string_to_dna_labelencoder(df, 'seq', 'label')
# # Preprocess
# preprocessor = PreprocessLLMData(df_encoded[0], y[0])
# inputs_list, masks_list, y_list, test_dataloader = preprocessor.preprocess()
# # Run inference
# probs, labels_list, logits = inference(bert_classifier, test_dataloader, device='cpu')
# predicted_class = int(probs >= 0.5)
# actual_class = int(np.argmax(y_list[0]))
# probability = float(probs)
# return {
# "input_sequence": seq,
# "actual": actual_class,
# "predicted": predicted_class,
# "confidence": probability,
# "label_name": "Human-Pathogenic" if predicted_class == 1 else "Non-Human"
# }
# except Exception as e:
# return {"error": str(e)}
# # === Gradio Interface ===
# api = gr.Interface(
# fn=predict_dna,
# inputs=gr.Textbox(label="DNA Sequence"),
# outputs="json",
# )
# if __name__ == "__main__":
# api.launch(server_name="0.0.0.0", server_port=7860, show_api=True)