scruzlara commited on
Commit
abe6273
·
verified ·
1 Parent(s): fbc714c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -46
app.py CHANGED
@@ -7,68 +7,57 @@ import numpy as np
7
  import os
8
  import random
9
 
10
- # Device configuration
11
  device = torch.device('cpu')
12
 
13
- # Disease labels
14
  labels = {
15
- 0: 'bacterial_leaf_blight',
16
- 1: 'bacterial_leaf_streak',
17
- 2: 'bacterial_panicle_blight',
18
- 3: 'blast',
19
- 4: 'brown_spot',
20
- 5: 'dead_heart',
21
- 6: 'downy_mildew',
22
- 7: 'hispa',
23
- 8: 'normal',
24
- 9: 'tungro'
25
- }
26
-
27
- def inference_fn(model, image=None):
28
  model.eval()
29
- image = image.to(device)
30
  with torch.no_grad():
31
  output = model(image.unsqueeze(0))
32
- out = output.sigmoid().detach().cpu().numpy().flatten()
33
  return out
34
-
35
- def predict(image=None) -> dict:
36
- if image is None:
37
- return {label: 0.0 for label in labels.values()}
38
-
39
- # Image preprocessing
40
- mean = (0.485, 0.456, 0.406)
41
- std = (0.229, 0.224, 0.225)
42
 
43
- augmentations = albumentations.Compose([
44
- albumentations.Resize(256, 256),
45
- albumentations.HorizontalFlip(p=0.5),
46
- albumentations.VerticalFlip(p=0.5),
47
- albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
48
- ])
 
 
 
 
 
 
 
49
 
50
  augmented = augmentations(image=image)
51
  image = augmented["image"]
52
  image = np.transpose(image, (2, 0, 1))
53
  image = torch.tensor(image, dtype=torch.float32)
54
-
55
- # Model initialization
56
  model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=10)
57
  model.load_state_dict(torch.load("paddy_model.pth", map_location=torch.device(device)))
58
  model.to(device)
59
-
60
- # Make prediction
61
  predicted = inference_fn(model, image)
 
62
  return {labels[i]: float(predicted[i]) for i in range(10)}
 
63
 
64
- # Create and launch Gradio interface with API support
65
- demo = gr.Interface(
66
- fn=predict,
67
- inputs=gr.Image(),
68
- outputs=gr.Label(num_top_classes=10),
69
- examples=["200005.jpg", "200006.jpg"],
70
- interpretation='default'
71
- )
72
-
73
- # Launch with API enabled for Hugging Face Spaces
74
- demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_api=True)
 
7
  import os
8
  import random
9
 
 
10
  device = torch.device('cpu')
11
 
 
12
  labels = {
13
+ 0: 'bacterial_leaf_blight',
14
+ 1: 'bacterial_leaf_streak',
15
+ 2: 'bacterial_panicle_blight',
16
+ 3: 'blast',
17
+ 4: 'brown_spot',
18
+ 5: 'dead_heart',
19
+ 6: 'downy_mildew',
20
+ 7: 'hispa',
21
+ 8: 'normal',
22
+ 9: 'tungro'
23
+ }
24
+
25
+ def inference_fn(model, image=None):
26
  model.eval()
27
+ image = image.to(device)
28
  with torch.no_grad():
29
  output = model(image.unsqueeze(0))
30
+ out = output.sigmoid().detach().cpu().numpy().flatten()
31
  return out
 
 
 
 
 
 
 
 
32
 
33
+
34
+ def predict(image=None) -> dict:
35
+ mean = (0.485, 0.456, 0.406)
36
+ std = (0.229, 0.224, 0.225)
37
+
38
+ augmentations = albumentations.Compose(
39
+ [
40
+ albumentations.Resize(256, 256),
41
+ albumentations.HorizontalFlip(p=0.5),
42
+ albumentations.VerticalFlip(p=0.5),
43
+ albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
44
+ ]
45
+ )
46
 
47
  augmented = augmentations(image=image)
48
  image = augmented["image"]
49
  image = np.transpose(image, (2, 0, 1))
50
  image = torch.tensor(image, dtype=torch.float32)
 
 
51
  model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=10)
52
  model.load_state_dict(torch.load("paddy_model.pth", map_location=torch.device(device)))
53
  model.to(device)
54
+
 
55
  predicted = inference_fn(model, image)
56
+
57
  return {labels[i]: float(predicted[i]) for i in range(10)}
58
+
59
 
60
+ gr.Interface(fn=predict,
61
+ inputs=gr.inputs.Image(),
62
+ outputs=gr.outputs.Label(num_top_classes=10),
63
+ examples=["200005.jpg", "200006.jpg"], interpretation='default').launch()