AI-Manith commited on
Commit
ca41343
·
verified ·
1 Parent(s): e3fa2f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -17,34 +17,39 @@ def load_model_and_mtcnn(model_path):
17
  def preprocess_image(image, mtcnn, device):
18
  processed_image = image # Initialize with the original image
19
  try:
20
- # The return_image parameter of MTCNN's forward method can return the original image along with detected faces, but here we directly pass the image
21
- cropped_faces, _ = mtcnn(image, return_image=True)
22
  if cropped_faces is not None and len(cropped_faces) > 0:
23
  processed_image = cropped_faces[0] # Use the first detected face
24
- # No else clause needed; if no faces detected, processed_image remains the original
25
  except Exception as e:
26
  st.write(f"Exception in face detection: {e}")
27
  processed_image = image
28
 
29
  transform = transforms.Compose([
30
- transforms.Resize((224, 224)),
31
  transforms.ToTensor(),
32
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33
  ])
34
- image_tensor = transform(processed_image).to(device)
 
 
 
 
 
35
  image_tensor = image_tensor.unsqueeze(0) # Add a batch dimension
36
- return image_tensor, processed_image
37
 
38
  # Function for inference
39
  def predict(image_tensor, model, device):
40
  model.eval()
41
  with torch.no_grad():
42
  outputs = model(image_tensor)
 
43
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
44
  predicted_class = torch.argmax(probabilities, dim=1)
45
  return predicted_class, probabilities
46
 
47
- # Streamlit UI
48
  st.title("Face Detection and Classification with ViT")
49
  st.write("Upload an image, and the model will detect faces and classify the image.")
50
 
@@ -58,6 +63,7 @@ if uploaded_file is not None:
58
  image_tensor, final_image = preprocess_image(image, mtcnn, device)
59
  predicted_class, probabilities = predict(image_tensor, model, device)
60
 
 
61
  st.write(f"Predicted class: {predicted_class.item()}")
62
  # Display the final processed image
63
  st.image(final_image, caption='Processed Image', use_column_width=True)
 
17
  def preprocess_image(image, mtcnn, device):
18
  processed_image = image # Initialize with the original image
19
  try:
20
+ # Directly call mtcnn with the image to get cropped faces
21
+ cropped_faces = mtcnn(image)
22
  if cropped_faces is not None and len(cropped_faces) > 0:
23
  processed_image = cropped_faces[0] # Use the first detected face
 
24
  except Exception as e:
25
  st.write(f"Exception in face detection: {e}")
26
  processed_image = image
27
 
28
  transform = transforms.Compose([
29
+ transforms.Resize((224, 224)),
30
  transforms.ToTensor(),
31
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
32
  ])
33
+ # Ensure processed_image is a PIL Image for the transformation
34
+ if not isinstance(processed_image, Image.Image):
35
+ processed_image_pil = Image.fromarray(processed_image.cpu().numpy().astype('uint8'), 'RGB')
36
+ else:
37
+ processed_image_pil = processed_image
38
+ image_tensor = transform(processed_image_pil).to(device)
39
  image_tensor = image_tensor.unsqueeze(0) # Add a batch dimension
40
+ return image_tensor, processed_image_pil
41
 
42
  # Function for inference
43
  def predict(image_tensor, model, device):
44
  model.eval()
45
  with torch.no_grad():
46
  outputs = model(image_tensor)
47
+ # Adjust for your model's output if it does not have a 'logits' attribute
48
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
49
  predicted_class = torch.argmax(probabilities, dim=1)
50
  return predicted_class, probabilities
51
 
52
+ # Streamlit UI setup
53
  st.title("Face Detection and Classification with ViT")
54
  st.write("Upload an image, and the model will detect faces and classify the image.")
55
 
 
63
  image_tensor, final_image = preprocess_image(image, mtcnn, device)
64
  predicted_class, probabilities = predict(image_tensor, model, device)
65
 
66
+ # Here, customize the display of predicted_class and probabilities based on your model's specifics
67
  st.write(f"Predicted class: {predicted_class.item()}")
68
  # Display the final processed image
69
  st.image(final_image, caption='Processed Image', use_column_width=True)