ash12321 committed on
Commit
5e8ae74
·
verified ·
1 Parent(s): 9f99a53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -61
app.py CHANGED
@@ -1,72 +1,119 @@
# app.py
import torch
from torchvision import transforms, models
from PIL import Image
import gradio as gr
from transformers import pipeline

# -----------------------------
# CONFIG
# -----------------------------
# Run on GPU when available; the checkpoint is remapped onto this device below.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_MODEL_PATH = "models/deepfake_model.pth"

# -----------------------------
# LOAD DEEPFAKE IMAGE MODEL
# -----------------------------
# weights=None: any pretrained ImageNet weights would be fully overwritten by
# load_state_dict() below, so downloading them (the old pretrained=True, a
# deprecated kwarg in modern torchvision) was wasted work.
deepfake_model = models.resnet18(weights=None)
num_features = deepfake_model.fc.in_features
# Replace the 1000-class ImageNet head with a binary Real/Fake head.
deepfake_model.fc = torch.nn.Linear(num_features, 2)
deepfake_model.load_state_dict(torch.load(IMAGE_MODEL_PATH, map_location=DEVICE))
deepfake_model.eval()
deepfake_model.to(DEVICE)

# -----------------------------
# LOAD TEXT MODEL (already working)
# -----------------------------
text_model = pipeline("text-classification", model="mrm8488/bert-tiny-finetuned-fake-news")
 
 
 
28
 
29
# -----------------------------
# IMAGE PREDICTION FUNCTION
# -----------------------------
# NOTE(review): no Normalize() step — assumes the model was trained on raw
# [0, 1] tensors; confirm against the training transform.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def predict_image(image):
    """Classify an image as Real/Fake and return a confidence string.

    Args:
        image: HxWxC numpy array from the Gradio image widget. May be None
            when the user clicks Predict without uploading anything.

    Returns:
        A human-readable "<label> (<pct>% confidence)" string, or a prompt
        to upload an image when none was provided.
    """
    # BUG FIX: Image.fromarray(None) raised; guard the empty-input case.
    if image is None:
        return "Please upload an image."
    img = Image.fromarray(image)
    img = image_transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = deepfake_model(img)
    # Softmax over the two logits; index 0 = Real, index 1 = Fake.
    probs = torch.softmax(output, dim=1)
    label = "Real" if torch.argmax(probs) == 0 else "Fake"
    confidence = torch.max(probs).item()
    return f"{label} ({confidence*100:.2f}% confidence)"
46
 
47
# -----------------------------
# TEXT PREDICTION FUNCTION
# -----------------------------
def predict_text(text):
    """Run the fake-news text classifier and format its top prediction."""
    top = text_model(text)[0]
    label, score = top["label"], top["score"]
    return f"{label} ({score*100:.2f}% confidence)"
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
# -----------------------------
# BUILD UI
# -----------------------------
# Two tabs share the same shape: input widget -> prediction textbox -> button
# wired to the matching predict_* function defined above.
with gr.Blocks() as demo:
    gr.Markdown("## Deepfake & Fake News Detector")

    with gr.Tab("Image Detection"):
        # type="numpy" hands predict_image an HxWxC array (or None if empty).
        image_input = gr.Image(type="numpy")
        image_output = gr.Textbox(label="Prediction")
        image_button = gr.Button("Predict Image")
        image_button.click(predict_image, inputs=image_input, outputs=image_output)

    with gr.Tab("Text Detection"):
        text_input = gr.Textbox(label="Enter Text")
        text_output = gr.Textbox(label="Prediction")
        text_button = gr.Button("Predict Text")
        text_button.click(predict_text, inputs=text_input, outputs=text_output)

demo.launch()
 
1
# app.py
import gradio as gr
from transformers import pipeline, logging

logging.set_verbosity_error()  # mute transformers INFO logs to keep the UI logs clean

# -----------------------
# NOTE: These are public models known to load on Spaces.
# - text model: small DistilBERT sentiment model (used as a safe demo for text "credibility")
# - image model: ViT image-classifier (generic). Replace later with a custom deepfake model when ready.
# -----------------------
TEXT_MODEL_ID = "distilbert-base-uncased-finetuned-sst-2-english"
IMAGE_MODEL_ID = "google/vit-base-patch16-224"


def _load_pipeline(task, model_id):
    """Attempt to build a pipeline; return (pipe, None) or (None, error_message)."""
    try:
        return pipeline(task, model=model_id), None
    except Exception as e:
        return None, str(e)


# Load pipelines (will download weights on first run). On failure the pipe is
# None and the error text is surfaced in the UI instead of crashing the app.
text_pipe, text_load_error = _load_pipeline("text-classification", TEXT_MODEL_ID)
image_pipe, image_load_error = _load_pipeline("image-classification", IMAGE_MODEL_ID)
30
 
31
# Friendly mapping (different text models return different label names)
TEXT_FRIENDLY = {
    "NEGATIVE": "Not credible / Fake (demo)",
    "LABEL_0": "Not credible / Fake (demo)",
    "POSITIVE": "Credible / Real (demo)",
    "LABEL_1": "Credible / Real (demo)",
}


def friendly_text_label(raw_label: str) -> str:
    """Translate a raw classifier label into a human-friendly string.

    None maps to "Unknown"; labels missing from TEXT_FRIENDLY pass through
    unchanged (not uppercased), matching the raw value the model emitted.
    """
    if raw_label is None:
        return "Unknown"
    return TEXT_FRIENDLY.get(str(raw_label).upper(), raw_label)
 
 
 
 
44
 
45
def classify_text(text: str):
    """Classify pasted text and return (summary string, label->score dict).

    Gracefully degrades: empty input and a failed model load both return a
    message instead of raising, and any pipeline error is caught and reported.
    """
    if text is None or not str(text).strip():
        return "Please paste some text to analyze.", {}
    if text_pipe is None:
        return f"Text model failed to load: {text_load_error}", {}
    try:
        # top_k=2 yields a list of {'label': ..., 'score': ...} dicts; map raw
        # labels to friendly names and keep the scores for the Label widget.
        scores = {
            friendly_text_label(pred.get("label")): float(pred.get("score", 0.0))
            for pred in text_pipe(text, top_k=2)
        }
        best_label, best_score = max(scores.items(), key=lambda kv: kv[1])
        return f"{best_label} ({best_score*100:.2f}%)", scores
    except Exception as e:
        return f"Error during text classification: {e}", {}
64
 
65
def classify_image(image):
    """Classify an uploaded PIL image; return (summary, label->score dict).

    Returns a prompt when no image was given and an error message when the
    model failed to load or inference raises.
    """
    if image is None:
        return "Please upload an image.", {}
    if image_pipe is None:
        return f"Image model failed to load: {image_load_error}", {}
    try:
        scores = {}
        for pred in image_pipe(image, top_k=5):
            scores[pred["label"]] = float(pred["score"])
        best = max(scores.items(), key=lambda kv: kv[1])
        return f"{best[0]} ({best[1]*100:.2f}%)", scores
    except Exception as e:
        return f"Error during image classification: {e}", {}
78
+
79
# --- UI ---
with gr.Blocks(title="AI Detector (Text + Image)") as demo:
    gr.Markdown("## 🔎 AI Detector\nText (credibility demo) and Image (generic classifier).")
    gr.Markdown(
        "> This app uses public models that load in Spaces. When you have your own trained deepfake model, "
        "you can swap the image model ID in `app.py` to point at your Hugging Face model."
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Text Analysis")
            txt = gr.Textbox(lines=6, placeholder="Paste text here...", label="Input Text")
            txt_result = gr.Textbox(label="Summary")
            txt_probs = gr.Label(label="Confidence (top 2)")
            with gr.Row():
                btn_txt = gr.Button("Analyze Text")
                btn_txt.click(classify_text, inputs=txt, outputs=[txt_result, txt_probs])
                btn_txt_clear = gr.Button("Clear")
                # BUG FIX: the cleared values must match the outputs order
                # [txt, txt_result, txt_probs]. The previous lambda returned
                # ("", {}, ""), which pushed a dict into the Summary Textbox
                # and a string into the Label component.
                btn_txt_clear.click(lambda: ("", "", {}), outputs=[txt, txt_result, txt_probs])

        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Image Analysis")
            img = gr.Image(type="pil", label="Upload Image")
            img_result = gr.Textbox(label="Summary")
            img_probs = gr.Label(label="Top-5 Confidence")
            with gr.Row():
                btn_img = gr.Button("Analyze Image")
                btn_img.click(classify_image, inputs=img, outputs=[img_result, img_probs])
                btn_img_clear = gr.Button("Clear")
                # Image/summary/label cleared in the same order as outputs.
                btn_img_clear.click(lambda: (None, "", {}), outputs=[img, img_result, img_probs])

    # Footer: show load errors if any
    if text_load_error or image_load_error:
        with gr.Column():
            gr.Markdown("**Model load warnings:**")
            if text_load_error:
                gr.Markdown(f"- Text model load error: `{text_load_error}`")
            if image_load_error:
                gr.Markdown(f"- Image model load error: `{image_load_error}`")

demo.launch()