Safi029 commited on
Commit
5e1593e
Β·
verified Β·
1 Parent(s): 3e286ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -8
app.py CHANGED
@@ -5,34 +5,49 @@ from PIL import Image
5
  import torch
6
  import torch.serialization
7
 
8
- # βœ… Add required classes for safe unpickling
9
- from ultralytics.nn.tasks import DetectionModel
10
  from ultralytics.nn.modules.conv import Conv
11
- from torch.nn import Sequential, Conv2d, BatchNorm2d
 
12
 
 
13
  torch.serialization.add_safe_globals([
14
- DetectionModel, Sequential, Conv, Conv2d, BatchNorm2d
 
 
 
 
 
 
 
 
 
 
15
  ])
16
 
17
- # πŸ“₯ Load model from Hugging Face
18
  model_path = hf_hub_download(repo_id="Safi029/ABD-model", filename="ABD.pt")
 
 
19
  model = YOLO(model_path)
20
 
21
- # 🧠 Define detection function
22
  def detect_structure(image):
23
  results = model(image)
24
  return Image.fromarray(results[0].plot())
25
 
26
- # πŸŽ›οΈ Build Gradio interface
27
  demo = gr.Interface(
28
  fn=detect_structure,
29
  inputs=gr.Image(type="pil"),
30
  outputs=gr.Image(type="pil"),
31
  title="YOLO Molecular Detector",
32
- description="Upload a molecule image. YOLOv8 detects atoms or bonds."
33
  )
34
 
35
  demo.launch()
36
 
37
 
38
 
 
 
5
  import torch
6
  import torch.serialization
7
 
8
+ # βœ… Import all model components that may appear in custom YOLOv8 models
9
+ from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU
10
  from ultralytics.nn.modules.conv import Conv
11
+ from ultralytics.nn.modules.block import Bottleneck, C2f, SPPF
12
+ from ultralytics.nn.tasks import DetectionModel
13
 
14
+ # βœ… Add all components to safe globals
15
  torch.serialization.add_safe_globals([
16
+ DetectionModel,
17
+ Sequential,
18
+ Conv,
19
+ Conv2d,
20
+ BatchNorm2d,
21
+ SiLU,
22
+ ReLU,
23
+ LeakyReLU,
24
+ Bottleneck,
25
+ C2f,
26
+ SPPF,
27
  ])
28
 
29
+ # πŸ“₯ Download model
30
  model_path = hf_hub_download(repo_id="Safi029/ABD-model", filename="ABD.pt")
31
+
32
+ # βœ… Load model safely
33
  model = YOLO(model_path)
34
 
35
+ # 🧠 Inference function
36
  def detect_structure(image):
37
  results = model(image)
38
  return Image.fromarray(results[0].plot())
39
 
40
+ # πŸŽ›οΈ Gradio UI
41
  demo = gr.Interface(
42
  fn=detect_structure,
43
  inputs=gr.Image(type="pil"),
44
  outputs=gr.Image(type="pil"),
45
  title="YOLO Molecular Detector",
46
+ description="Upload a molecular structure image. YOLOv8 will detect atoms and bonds."
47
  )
48
 
49
  demo.launch()
50
 
51
 
52
 
53
+