jibsn commited on
Commit
3ecd782
·
verified ·
1 Parent(s): 87e937c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -18
app.py CHANGED
@@ -2,10 +2,13 @@ import gradio as gr
2
  import onnxruntime as ort
3
  import numpy as np
4
  from PIL import Image
 
 
5
  import io
6
  import rdkit
7
  from rdkit import Chem
8
  from rdkit.Chem import Draw
 
9
  from utils import bbox_to_graph_with_charge, mol_from_graph_with_chiral
10
 
11
  bond_labels = [13,14,15,16,17]
@@ -13,20 +16,26 @@ idx_to_labels = {0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B',
13
  9:'I',10:'P',11:'*',12:'Si',13:'NONE',14:'BEGINWEDGE',15:'BEGINDASH',
14
  16:'=',17:'#',18:'-4',19:'-2',20:'-1',21:'1',22:'+2',} #NONE is single ?
15
 
16
- def preprocess_image(image):
17
- """
18
- 预处理输入图片
19
- """
20
- # 将图片调整为模型所需的输入尺寸
21
- image = image.resize((640, 640)) # 根据实际模型需求调整尺寸
22
- # 转换为numpy数组并归一化
23
- img_array = np.array(image)
24
- img_array = img_array.astype(np.float32) / 255.0
25
- # 添加批次维度
26
- img_array = np.expand_dims(img_array, axis=0)
27
- # 根据模型训练时的预处理方式进行调整
28
- img_array = img_array.transpose(0, 3, 1, 2) # BHWC to BCHW
29
- return img_array
 
 
 
 
 
 
30
 
31
  def visualize_molecule(smiles):
32
  """
@@ -50,17 +59,41 @@ def predict(input_image):
50
  session = ort.InferenceSession("model.onnx") # 替换为实际模型路径
51
 
52
  # 预处理图片
53
- processed_image = preprocess_image(input_image)
 
 
54
 
55
  # 获取模型输入输出名称
56
  input_name = session.get_inputs()[0].name
57
  output_name = session.get_outputs()[0].name
58
 
59
  # 进行推理
60
- predictions = session.run([output_name], {input_name: processed_image})
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- # 假设模型输出是SMILES字符串
63
- output = predictions # 根据实际模型输出格式调整
 
 
 
 
 
 
 
 
 
 
64
  atoms_df, bonds_list,charge_list =bbox_to_graph_with_charge(output, idx_to_labels=idx_to_labels,
65
  bond_labels=bond_labels, result=[])
66
  smiles,mol_rebuit=mol_from_graph_with_chiral(atoms_df, bonds_list,charge_list )
 
2
  import onnxruntime as ort
3
  import numpy as np
4
  from PIL import Image
5
+ from torchvision import transforms
6
+ import torchvision.transforms.v2 as T
7
  import io
8
  import rdkit
9
  from rdkit import Chem
10
  from rdkit.Chem import Draw
11
+ from postprocessor import RTDETRPostProcessor
12
  from utils import bbox_to_graph_with_charge, mol_from_graph_with_chiral
13
 
14
  bond_labels = [13,14,15,16,17]
 
16
  9:'I',10:'P',11:'*',12:'Si',13:'NONE',14:'BEGINWEDGE',15:'BEGINDASH',
17
  16:'=',17:'#',18:'-4',19:'-2',20:'-1',21:'1',22:'+2',} #NONE is single ?
18
 
19
+
20
+ def image_to_tensor(image_path):
21
+ # Open the image using PIL
22
+ image = Image.open(image_path)
23
+ w, h = image.size
24
+ # print("width: {}, height: {}".format(w, h))
25
+ # Define a transform to convert the image to a tensor and normalize it
26
+ transform = transforms.Compose([
27
+ # transforms.Grayscale(num_output_channels=1), # Convert to grayscale (1 channel)
28
+ T.Resize((640, 640)), # Resize the image to 224x224
29
+ T.ToImageTensor(), # Convert to Tensor (C x H x W)
30
+ T.ConvertDtype(dtype=torch.float32)
31
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Optional normalization for pretrained models
32
+ ])
33
+
34
+ # Apply the transform to the image
35
+ tensor = transform(image)
36
+
37
+ return tensor,w,h
38
+
39
 
40
  def visualize_molecule(smiles):
41
  """
 
59
  session = ort.InferenceSession("model.onnx") # 替换为实际模型路径
60
 
61
  # 预处理图片
62
+ # Example usage: #change thie image
63
+ tensor,w,h = image_to_tensor(input_image)
64
+ processed_image=tensor.unsqueeze(0)
65
 
66
  # 获取模型输入输出名称
67
  input_name = session.get_inputs()[0].name
68
  output_name = session.get_outputs()[0].name
69
 
70
  # 进行推理
71
+ outputs = session.run([output_name], {input_name: processed_image})
72
+ ori_size=torch.Tensor([w,h]).long().unsqueeze(0)
73
+ postprocessor = RTDETRPostProcessor()
74
+ result_ = postprocessor(outputs, ori_size)
75
+ score_=result_[0]['scores']
76
+ boxe_=result_[0]['boxes']
77
+ label_=result_[0]['labels']
78
+ selected_indices =score_ > 0.5
79
+ output={
80
+ 'labels': label_[selected_indices],
81
+ 'boxes': boxe_[selected_indices],
82
+ 'scores': score_[selected_indices]
83
+ }
84
 
85
+ filtered_output_dict={image_path: output
86
+ }
87
+
88
+
89
+ x_center = (output["boxes"][:, 0] + output["boxes"][:, 2]) / 2
90
+ y_center = (output["boxes"][:, 1] + output["boxes"][:, 3]) / 2
91
+ center_coords = torch.stack((x_center, y_center), dim=1)
92
+ output = {'bbox': output["boxes"].to("cpu").numpy(),
93
+ 'bbox_centers': center_coords.to("cpu").numpy(),
94
+ 'scores': output["scores"].to("cpu").numpy(),
95
+ 'pred_classes': output["labels"].to("cpu").numpy()}
96
+
97
  atoms_df, bonds_list,charge_list =bbox_to_graph_with_charge(output, idx_to_labels=idx_to_labels,
98
  bond_labels=bond_labels, result=[])
99
  smiles,mol_rebuit=mol_from_graph_with_chiral(atoms_df, bonds_list,charge_list )