SSD0907 commited on
Commit
9363d50
Β·
verified Β·
1 Parent(s): 48548a1

Update app.py

Browse files

Corrected Mismatched Labels

Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -26,12 +26,18 @@ if model_files:
26
  else:
27
  st.error("❌ No .pt file found in the unzipped model folder.")
28
 
 
 
 
 
 
 
29
  # Upload image
30
  uploaded_image = st.file_uploader("πŸ“ Upload an Image", type=["jpg", "jpeg", "png"])
31
 
32
  if uploaded_image is not None:
33
  image = Image.open(uploaded_image).convert("RGB")
34
- st.image(image, caption="Uploaded Image", width=200) # πŸ‘ˆ Adjusted image size here
35
 
36
  # Prediction button
37
  if st.button("Detect Deepfake"):
@@ -39,11 +45,12 @@ if uploaded_image is not None:
39
  results = model.predict(image)
40
 
41
  # Draw boxes on the image
42
- result_image = results[0].plot() # This plots bounding boxes
43
 
44
  # Convert to PIL Image and display
45
  result_pil = Image.fromarray(result_image[..., ::-1]) # BGR to RGB
46
- st.image(result_pil, caption="Detection Result", width=200) # πŸ‘ˆ Adjusted result image size here
47
 
48
- # Optional: Show label info
49
- st.write("πŸ”Ž Detected Labels:", results[0].names)
 
 
26
  else:
27
  st.error("❌ No .pt file found in the unzipped model folder.")
28
 
29
+ # Define corrected labels
30
+ corrected_labels = {
31
+ 0: "Fake",
32
+ 1: "Real"
33
+ }
34
+
35
  # Upload image
36
  uploaded_image = st.file_uploader("πŸ“ Upload an Image", type=["jpg", "jpeg", "png"])
37
 
38
  if uploaded_image is not None:
39
  image = Image.open(uploaded_image).convert("RGB")
40
+ st.image(image, caption="Uploaded Image", width=200)
41
 
42
  # Prediction button
43
  if st.button("Detect Deepfake"):
 
45
  results = model.predict(image)
46
 
47
  # Draw boxes on the image
48
+ result_image = results[0].plot()
49
 
50
  # Convert to PIL Image and display
51
  result_pil = Image.fromarray(result_image[..., ::-1]) # BGR to RGB
52
+ st.image(result_pil, caption="Detection Result", width=200)
53
 
54
+ # Show corrected label info
55
+ st.subheader("πŸ”Ž Detected Labels:")
56
+ st.json(corrected_labels)