Rujit commited on
Commit
aeea6fb
·
verified ·
1 Parent(s): 79035ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -53
app.py CHANGED
@@ -1,53 +1,53 @@
1
- import gradio as gr
2
- import torch
3
- from torchvision import models, transforms
4
- from PIL import Image
5
- import torch.nn.functional as F
6
- import torch.nn as nn
7
-
8
- # Class labels
9
- class_names = ['fake', 'real']
10
-
11
- # Image transform
12
- data_transforms = transforms.Compose([
13
- transforms.Resize((224, 224)),
14
- transforms.ToTensor(),
15
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
16
- ])
17
-
18
- # Load model
19
- def load_model():
20
- model = models.densenet121(weights='IMAGENET1K_V1')
21
- model.classifier = nn.Sequential(
22
- nn.Linear(1024, 512),
23
- nn.ReLU(),
24
- nn.Dropout(0.5),
25
- nn.Linear(512, 2)
26
- )
27
- device = torch.device('cpu') # Use CPU for Hugging Face
28
- model = model.to(device)
29
- checkpoint = torch.load("best_model.pth", map_location=device)
30
- model.load_state_dict(checkpoint['model_state_dict'])
31
- model.eval()
32
- return model, device
33
-
34
- model, device = load_model()
35
-
36
- # Inference function
37
- def predict(image):
38
- if image.mode == "RGBA":
39
- image = image.convert("RGB")
40
-
41
- image = data_transforms(image).unsqueeze(0).to(device)
42
- with torch.no_grad():
43
- outputs = model(image)
44
- probs = F.softmax(outputs, dim=1)
45
- conf, pred = torch.max(probs, 1)
46
-
47
- label = class_names[pred.item()]
48
- confidence = f"{conf.item() * 100:.2f}%"
49
- return f"{label} ({confidence})"
50
-
51
- # Gradio interface
52
- demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text", title="Fake/Real Image Classifier")
53
- demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ import torch.nn.functional as F
6
+ import torch.nn as nn
7
+
8
+ # Class labels
9
+ class_names = ['fake', 'real']
10
+
11
+ # Image transform
12
+ data_transforms = transforms.Compose([
13
+ transforms.Resize((224, 224)),
14
+ transforms.ToTensor(),
15
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
16
+ ])
17
+
18
+ # Load model
19
+ def load_model():
20
+ model = models.densenet121(weights='IMAGENET1K_V1')
21
+ model.classifier = nn.Sequential(
22
+ nn.Linear(1024, 512),
23
+ nn.ReLU(),
24
+ nn.Dropout(0.5),
25
+ nn.Linear(512, 2)
26
+ )
27
+ device = torch.device('cpu') # Use CPU for Hugging Face
28
+ model = model.to(device)
29
+ checkpoint = torch.load("best_model.pth", map_location=device)
30
+ model.load_state_dict(checkpoint['model_state_dict'])
31
+ model.eval()
32
+ return model, device
33
+
34
+ model, device = load_model()
35
+
36
+ # Inference function
37
+ def predict(image):
38
+ if image.mode == "RGBA":
39
+ image = image.convert("RGB")
40
+
41
+ image = data_transforms(image).unsqueeze(0).to(device)
42
+ with torch.no_grad():
43
+ outputs = model(image)
44
+ probs = F.softmax(outputs, dim=1)
45
+ conf, pred = torch.max(probs, 1)
46
+
47
+ label = class_names[pred.item()]
48
+ confidence = f"{conf.item() * 100:.2f}%"
49
+ return f"{label} ({confidence})"
50
+
51
+ # Gradio interface
52
+ demo = gr.Interface(fn=predict, inputs="image", outputs="text", api_name="predict")
53
+ demo.launch()