ttoosi commited on
Commit
bdecc48
·
verified ·
1 Parent(s): 3de0f2a

Update app.py

Browse files

Omitted ds completely

Files changed (1) hide show
  1. app.py +19 -19
app.py CHANGED
@@ -9,11 +9,11 @@ from PIL import Image
9
  import numpy as np
10
  import random
11
 
12
- from datasets import load_dataset
13
- from datasets import DatasetDict
14
- ds = DatasetDict({
15
- "validation": load_dataset("chronopt-research/cropped-vggface2-224", split="validation"),
16
- })
17
 
18
 
19
 
@@ -49,22 +49,22 @@ preprocess = transforms.Compose([
49
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # vggface2
50
  ])
51
 
52
- # Function to make predictions
53
- def predict(image):
54
- if isinstance(image, np.ndarray):
55
- image = Image.fromarray(image) # Convert to PIL Image if i
56
- image = preprocess(image).unsqueeze(0) # Add batch dimension
57
- with torch.no_grad():
58
- output = model(image) # Perform inference on CPU
59
- _, predicted_class = output.max(1)
60
- # Fetch 9 random samples from the predicted class
61
- class_samples = ds.filter(lambda example: example['label'] == predicted_class.item())
62
-
63
- sample_images = random.sample(list(class_samples), min(len(class_samples), 9))
64
 
65
- sample_images_urls = [sample['image'] for sample in sample_images]
66
 
67
- return f"Predicted class: {predicted_class.item()}", sample_images_urls
68
 
69
 
70
  # Simplified Generative Inference
 
9
  import numpy as np
10
  import random
11
 
12
+ # from datasets import load_dataset
13
+ # from datasets import DatasetDict
14
+ # ds = DatasetDict({
15
+ # "validation": load_dataset("chronopt-research/cropped-vggface2-224", split="validation"),
16
+ # })
17
 
18
 
19
 
 
49
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # vggface2
50
  ])
51
 
52
+ # # Function to make predictions
53
+ # def predict(image):
54
+ # if isinstance(image, np.ndarray):
55
+ # image = Image.fromarray(image) # Convert to PIL Image if i
56
+ # image = preprocess(image).unsqueeze(0) # Add batch dimension
57
+ # with torch.no_grad():
58
+ # output = model(image) # Perform inference on CPU
59
+ # _, predicted_class = output.max(1)
60
+ # # Fetch 9 random samples from the predicted class
61
+ # class_samples = ds.filter(lambda example: example['label'] == predicted_class.item())
62
+
63
+ # sample_images = random.sample(list(class_samples), min(len(class_samples), 9))
64
 
65
+ # sample_images_urls = [sample['image'] for sample in sample_images]
66
 
67
+ # return f"Predicted class: {predicted_class.item()}", sample_images_urls
68
 
69
 
70
  # Simplified Generative Inference