SOUMYADIP MAL commited on
Commit
568422e
·
1 Parent(s): 346f4a4

commiting the meme classification hf demo

Browse files
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ scripts_and_models/efficientNet_clf.pt filter=lfs diff=lfs merge=lfs -text
37
+ example_imgs/* filter=lfs diff=lfs merge=lfs -text
38
+ example_imgs/*.jpg filter=lfs diff=lfs merge=lfs -text
39
+ example_imgs/meme.png filter=lfs diff=lfs merge=lfs -text
40
+ example_imgs/non-meme.jpg filter=lfs diff=lfs merge=lfs -text
example_imgs/meme.png ADDED

Git LFS Details

  • SHA256: e49fb5f664c2be99a8b2a209478cc961d9a99728556bf19723d4869461e8642e
  • Pointer size: 131 Bytes
  • Size of remote file: 568 kB
example_imgs/non-meme.jpg ADDED

Git LFS Details

  • SHA256: 1cccef217a1e5eec2acc74d7c08a105da3ef6fdb8514e221cdea9b0b373b60ea
  • Pointer size: 130 Bytes
  • Size of remote file: 83 kB
scripts_and_models/app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+ from pathlib import Path
6
+ from timeit import default_timer as timer
7
+ from typing import Tuple, Dict
8
+ from torchvision import transforms
9
+
10
+ class_names=['meme', 'non-meme']
11
+
12
+
13
+ model_path=Path("efficientNet_clf.pt")
14
+ model = torch.jit.load(model_path)
15
+ image_transform = transforms.Compose([
16
+ transforms.Resize((224,224)),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
19
+ std=[0.229, 0.224, 0.225]),
20
+ ])
21
+ print(image_transform)
22
+
23
+ def predict(img) -> Tuple[Dict, float]:
24
+ """Transforms and performs a prediction on img and returns prediction and time taken.
25
+ """
26
+
27
+ print("---img path is: ",img)
28
+ start_time = timer()
29
+ model.to("cpu")
30
+ model.eval()
31
+ with torch.inference_mode():
32
+ img = image_transform(img).unsqueeze(dim=0)
33
+ pred_probs = torch.softmax(model(img).to("cpu"), dim=1)
34
+
35
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
36
+ pred_time = round(timer() - start_time, 5)
37
+
38
+ return pred_labels_and_probs, pred_time
39
+
40
+ #print(e)
41
+ #return "error",0
42
+
43
+
44
+
45
+ title = "Meme classifiication"
46
+ description = "An EfficientNetB2 model to classify images of food into 2 classes:meme and non-meme"
47
+
48
+
49
+ example_list = ["../example_imgs/"+i for i in os.listdir("../example_imgs")]
50
+ #print(example_list)
51
+
52
+ demo = gr.Interface(
53
+ fn=predict,
54
+ inputs=gr.Image(type="pil"),
55
+ outputs=[
56
+ gr.Label(num_top_classes=2, label="Predictions"),
57
+ gr.Number(label="Prediction time (s)"),
58
+ ],
59
+ examples=example_list,
60
+ title=title,
61
+ description=description,
62
+ )
63
+
64
+ demo.launch()
65
+ #predict(example_list[0])
scripts_and_models/efficientNet_clf.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00aa3d1e2f5828f9529424a021577e181b779d77fc95a47ecc3d9f562d3b9b7e
3
+ size 16535370
scripts_and_models/inference.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ from PIL import Image
3
+ import torch
4
+ import torchvision
5
+ from torchvision import datasets, transforms
6
+ import matplotlib.pyplot as plt
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+
11
+
12
+ def pred_and_plot_image(model: torch.nn.Module,
13
+ image_path: str,
14
+ class_names: List[str],
15
+ image_size: Tuple[int, int] = (224, 224),
16
+ transform: torchvision.transforms = None,
17
+ device: torch.device=device):
18
+
19
+
20
+ img = Image.open(image_path)
21
+
22
+ if transform is not None:
23
+ image_transform = transform
24
+ else:
25
+ image_transform = transforms.Compose([
26
+ transforms.Resize(image_size),
27
+ transforms.ToTensor(),
28
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
29
+ std=[0.229, 0.224, 0.225]),
30
+ ])
31
+
32
+
33
+ model.to(device)
34
+ model.eval()
35
+ with torch.inference_mode():
36
+ transformed_image = image_transform(img).unsqueeze(dim=0)
37
+ target_image_pred = model(transformed_image.to(device))
38
+ target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
39
+ target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
40
+
41
+ plt.figure()
42
+ plt.imshow(img)
43
+ plt.title(f"Pred: {class_names[target_image_pred_label]} | Prob: {target_image_pred_probs.max():.3f}")
44
+ plt.axis(False);
45
+ plt.show()
46
+
47
+
48
+
49
+
50
+ from pathlib import Path
51
+
52
+ model_path=Path("efficientNet_clf.pt")
53
+ print(model_path)
54
+
55
+ model = torch.jit.load(model_path)
56
+
57
+ class_names=['meme', 'non-meme']
58
+
59
+ pred_and_plot_image(model=model,
60
+ image_path="../example_imgs/meme.png",
61
+ class_names=class_names)