maxhuber commited on
Commit
7e7d672
·
1 Parent(s): 0517016

Add reqs, adapted gradio demo, changed prediction output

Browse files
Files changed (5) hide show
  1. .gitignore +2 -1
  2. app.py +19 -51
  3. generate_prediction.py +7 -3
  4. notebook.ipynb +0 -0
  5. requirements.txt +6 -0
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  .idea/
2
  .idea
3
- data/
 
 
1
  .idea/
2
  .idea
3
+ data/
4
+ __pycache__
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from helpers import load_video_from_url, detect_deepfake
3
 
4
  theme = gr.themes.Default(
5
  primary_hue="stone",
@@ -10,61 +10,29 @@ theme = gr.themes.Default(
10
  font=[gr.themes.GoogleFont("IBM Plex Mono"), "system-ui"]
11
  )
12
 
13
- callback = gr.CSVLogger()
 
 
 
14
 
15
- with gr.Blocks(theme=theme) as demo:
16
- # DEFINE COMPONENTS
17
-
18
- # Text box for inputting Youtube URL
19
- urlInput = gr.Textbox(
20
- label="YOUTUBE VIDEO URL",
21
- value="https://www.youtube.com/watch?v=BmrUJhY9teE"
22
- )
23
-
24
- # Button for downloading the video and previewing sample frames
25
- loadVideoBtn = gr.Button("Load Video")
26
-
27
- # Text box for displaying video title
28
- videoTitle = gr.Textbox(
29
- label="VIDEO TITLE",
30
- lines=1,
31
- interactive=False
32
- )
33
 
34
- # Image Gallery for previewing sample frames
35
- sampleFrames = gr.Gallery(
36
- label="SAMPLE FRAMES",
37
- elem_id="gallery",
38
- columns=[3],
39
- rows=[1],
40
- object_fit="contain",
41
- height="auto"
42
- )
43
 
44
- # Button for generating video prediction
45
- predVideoBtn = gr.Button(value="Classify Video", visible=False)
46
-
47
- # Label for displaying prediction
48
- predOutput = gr.Label(
49
- label="DETECTED LABEL (AND CONFIDENCE LEVEL)",
50
- num_top_classes=2,
51
- visible=False
52
- )
53
-
54
- # Button for flagging the output
55
- flagBtn = gr.Button(value="Flag Output", visible=False)
56
-
57
- # DEFINE FUNCTIONS
58
- # Load video from URL, display sample frames, and enable prediction button
59
- loadVideoBtn.click(fn=load_video_from_url, inputs=[urlInput], outputs=[videoTitle, sampleFrames, predVideoBtn, predOutput])
60
 
61
- # Generate video prediction
62
- predVideoBtn.click(fn=detect_deepfake, outputs=[predOutput, flagBtn])
 
 
63
 
64
- # Define flag callback
65
- callback.setup([urlInput], "flagged_data_points")
 
66
 
67
- # Flag output
68
- flagBtn.click(fn=lambda *args: callback.flag(args), inputs=[urlInput], outputs=None)
69
 
70
  demo.launch()
 
1
  import gradio as gr
2
+ from generate_prediction import generate_prediction
3
 
4
  theme = gr.themes.Default(
5
  primary_hue="stone",
 
10
  font=[gr.themes.GoogleFont("IBM Plex Mono"), "system-ui"]
11
  )
12
 
13
+ def predict_image(image):
14
+ # Save the uploaded image to /image.jpeg
15
+ image_path = "./image.jpeg"
16
+ image.save(image_path)
17
 
18
+ # Call your model's prediction function
19
+ prediction = generate_prediction(image_path) # Assuming load_and_predict function exists in your model file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ return prediction
 
 
 
 
 
 
 
 
22
 
23
+ with gr.Blocks(theme=theme) as demo:
24
+ # DEFINE COMPONENTS
25
+ gr.Markdown("# MindReader Quantum")
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # Uploading the image input
28
+ with gr.Row():
29
+ image_input = gr.Image(type="pil", label="Upload Image")
30
+ output_label = gr.Label(label="Prediction")
31
 
32
+ # Button to submit and show the prediction
33
+ with gr.Row():
34
+ submit_btn = gr.Button("Submit")
35
 
36
+ submit_btn.click(fn=predict_image, inputs=image_input, outputs=output_label)
 
37
 
38
  demo.launch()
generate_prediction.py CHANGED
@@ -1,9 +1,14 @@
1
  import torch
 
2
  import pennylane as qml
3
  from model import build_hybrid_model
4
  from data_loaders import load_single_image, load_dataset, get_class_names
5
  from helpers import imshow
6
 
 
 
 
 
7
 
8
  def generate_prediction(img_path="./image.jpeg"):
9
 
@@ -42,13 +47,12 @@ def generate_prediction(img_path="./image.jpeg"):
42
  _, preds = torch.max(output, 1)
43
 
44
  # Load Dataset for class_names
45
- # TODO: Clean this up
46
  class_names = get_class_names(load_dataset())
47
  print("OUTPUT VEC:",output[0])
48
  print("HOT INDEX:",preds[0])
49
- print(class_names)
50
 
51
  # Return Model Prediction
52
- prediction = class_names[preds[0]]
 
53
 
54
  return prediction
 
1
  import torch
2
+ import numpy as np
3
  import pennylane as qml
4
  from model import build_hybrid_model
5
  from data_loaders import load_single_image, load_dataset, get_class_names
6
  from helpers import imshow
7
 
8
+ def softmax(x):
9
+ """Compute softmax values for each set of scores in x."""
10
+ e_x = np.exp(x - np.max(x)) # Subtracting max for numerical stability
11
+ return e_x / e_x.sum(axis=0)
12
 
13
  def generate_prediction(img_path="./image.jpeg"):
14
 
 
47
  _, preds = torch.max(output, 1)
48
 
49
  # Load Dataset for class_names
 
50
  class_names = get_class_names(load_dataset())
51
  print("OUTPUT VEC:",output[0])
52
  print("HOT INDEX:",preds[0])
 
53
 
54
  # Return Model Prediction
55
+ probabilities = softmax(output[0].cpu().numpy())
56
+ prediction = {class_name.replace("_", " "): prob for class_name, prob in zip(class_names, probabilities)}
57
 
58
  return prediction
notebook.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ matplotlib==3.9.2
3
+ PennyLane==0.38.0
4
+ Pillow==10.4.0
5
+ torch==2.4.1
6
+ torchvision==0.19.1