huzpsb commited on
Commit
ab697f7
·
verified ·
1 Parent(s): 53bd91d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +13 -6
  2. susu_s1.pth +3 -0
  3. susu_s2.pth +3 -0
app.py CHANGED
@@ -6,9 +6,10 @@ from PIL import Image
6
  # 0 1 2 3 4
7
  # Name, point-th, half-false-positive, balanced, half-true-positive
8
  presets = { #
 
 
9
  ("LatentDiffusion", 0.97, 0.65, 0.78, 0.88), #
10
  ("SdXl", 0.6, 0.49, 0.83, 0.89), #
11
- ("General", 0.5, 0.55, 0.87, 0.94), #
12
  ("DALL-E", 0.03, 0.8932, 0.9876, 0.999), #
13
  }
14
  # These hyperparameters are based on the model's performance on the validation set.
@@ -16,12 +17,16 @@ presets = { #
16
  if torch.cuda.is_available():
17
  device = torch.device('cuda')
18
  model = torch.load('m3_80.pth').to(device)
19
- refiner = torch.load('model2_p4.pth').to(device)
 
20
  else:
21
  device = torch.device('cpu')
22
  model = torch.load('m3_80.pth', map_location=device)
23
- refiner = torch.load('model2_p4.pth', map_location=device)
 
 
24
 
 
25
 
26
  def process_image(image, p_name):
27
  image = image.convert('RGB')
@@ -41,11 +46,13 @@ def process_image(image, p_name):
41
  np_img = tmp[begin_h:begin_h + 512, begin_w:begin_w + 512]
42
  torch_img = torch.tensor(np_img).float().unsqueeze(0) / 255.0
43
  torch_img = torch_img.to(device)
44
- y1 = model(torch_img).detach().cpu().numpy().squeeze()
45
- y1 = np.where(y1 > 0.5, 1, 0)
46
  x2 = torch.tensor(y1).float().unsqueeze(0).to(device)
47
- y2 = float(refiner(x2))
48
  y2 = min(max(y2, 0), 1)
 
 
49
  if y2 > 0.5:
50
  return "SlGeneral : AI Generated / R=" + str(round(y2 * 100, 2)) + " / Rule-based", None
51
  else:
 
6
  # 0 1 2 3 4
7
  # Name, point-th, half-false-positive, balanced, half-true-positive
8
  presets = { #
9
+ ("General (NoSl)", 0.5, 0.55, 0.87, 0.94), #
10
+ ("General", 0.5, 0.55, 0.87, 0.94), #
11
  ("LatentDiffusion", 0.97, 0.65, 0.78, 0.88), #
12
  ("SdXl", 0.6, 0.49, 0.83, 0.89), #
 
13
  ("DALL-E", 0.03, 0.8932, 0.9876, 0.999), #
14
  }
15
  # These hyperparameters are based on the model's performance on the validation set.
 
17
  if torch.cuda.is_available():
18
  device = torch.device('cuda')
19
  model = torch.load('m3_80.pth').to(device)
20
+ refiner1 = torch.load('susu_s1.pth').to(device)
21
+ refiner2 = torch.load('susu_s2.pth').to(device)
22
  else:
23
  device = torch.device('cpu')
24
  model = torch.load('m3_80.pth', map_location=device)
25
+ refiner1 = torch.load('susu_s1.pth', map_location=device)
26
+ refiner2 = torch.load('susu_s2.pth', map_location=device)
27
+
28
 
29
+ # ! susu_s1 model is not an e2e model thus you shall get nearly no accuracy if you use it alone !
30
 
31
  def process_image(image, p_name):
32
  image = image.convert('RGB')
 
46
  np_img = tmp[begin_h:begin_h + 512, begin_w:begin_w + 512]
47
  torch_img = torch.tensor(np_img).float().unsqueeze(0) / 255.0
48
  torch_img = torch_img.to(device)
49
+ y1 = refiner1(torch_img).detach().cpu().numpy().squeeze()
50
+ y1 = np.where(y1 > 0.837295, 1, 0) # @See -> tr_2 block -1
51
  x2 = torch.tensor(y1).float().unsqueeze(0).to(device)
52
+ y2 = float(refiner2(x2))
53
  y2 = min(max(y2, 0), 1)
54
+ if 0.2 < y2 < 0.8:
55
+ return "SlGeneral : No Comment, sorry! / P_A=" + str(round(y2 * 100, 2)) + " / Rule-based", None
56
  if y2 > 0.5:
57
  return "SlGeneral : AI Generated / R=" + str(round(y2 * 100, 2)) + " / Rule-based", None
58
  else:
susu_s1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60e1376b25bb0eceda123c7772294b124b7ef92a58d7772f6c070a2e7bffa89e
3
+ size 1468627
susu_s2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59ccdd8329f2380c6e82cd16d787adf596d5f4aded492605abaa8746a70c347e
3
+ size 11726371