SOUMYADIP MAL commited on
Commit
a647579
Β·
1 Parent(s): 568422e

changing the dir struct

Browse files
scripts_and_models/app.py β†’ app.py RENAMED
@@ -24,7 +24,7 @@ def predict(img) -> Tuple[Dict, float]:
24
  """Transforms and performs a prediction on img and returns prediction and time taken.
25
  """
26
 
27
- print("---img path is: ",img)
28
  start_time = timer()
29
  model.to("cpu")
30
  model.eval()
@@ -46,7 +46,7 @@ title = "Meme classifiication"
46
  description = "An EfficientNetB2 model to classify images of food into 2 classes:meme and non-meme"
47
 
48
 
49
- example_list = ["../example_imgs/"+i for i in os.listdir("../example_imgs")]
50
  #print(example_list)
51
 
52
  demo = gr.Interface(
 
24
  """Transforms and performs a prediction on img and returns prediction and time taken.
25
  """
26
 
27
+ #print("---img path is: ",img)
28
  start_time = timer()
29
  model.to("cpu")
30
  model.eval()
 
46
  description = "An EfficientNetB2 model to classify images of food into 2 classes:meme and non-meme"
47
 
48
 
49
+ example_list = ["./example_imgs/"+i for i in os.listdir("./example_imgs")]
50
  #print(example_list)
51
 
52
  demo = gr.Interface(
scripts_and_models/efficientNet_clf.pt β†’ efficientNet_clf.pt RENAMED
File without changes
scripts_and_models/inference.py DELETED
@@ -1,61 +0,0 @@
1
- from typing import List, Tuple
2
- from PIL import Image
3
- import torch
4
- import torchvision
5
- from torchvision import datasets, transforms
6
- import matplotlib.pyplot as plt
7
-
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
-
10
-
11
-
12
- def pred_and_plot_image(model: torch.nn.Module,
13
- image_path: str,
14
- class_names: List[str],
15
- image_size: Tuple[int, int] = (224, 224),
16
- transform: torchvision.transforms = None,
17
- device: torch.device=device):
18
-
19
-
20
- img = Image.open(image_path)
21
-
22
- if transform is not None:
23
- image_transform = transform
24
- else:
25
- image_transform = transforms.Compose([
26
- transforms.Resize(image_size),
27
- transforms.ToTensor(),
28
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
29
- std=[0.229, 0.224, 0.225]),
30
- ])
31
-
32
-
33
- model.to(device)
34
- model.eval()
35
- with torch.inference_mode():
36
- transformed_image = image_transform(img).unsqueeze(dim=0)
37
- target_image_pred = model(transformed_image.to(device))
38
- target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
39
- target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
40
-
41
- plt.figure()
42
- plt.imshow(img)
43
- plt.title(f"Pred: {class_names[target_image_pred_label]} | Prob: {target_image_pred_probs.max():.3f}")
44
- plt.axis(False);
45
- plt.show()
46
-
47
-
48
-
49
-
50
- from pathlib import Path
51
-
52
- model_path=Path("efficientNet_clf.pt")
53
- print(model_path)
54
-
55
- model = torch.jit.load(model_path)
56
-
57
- class_names=['meme', 'non-meme']
58
-
59
- pred_and_plot_image(model=model,
60
- image_path="../example_imgs/meme.png",
61
- class_names=class_names)