XORE21 commited on
Commit
6af2ae5
·
verified ·
1 Parent(s): e79654b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -27
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import base64
2
  import cv2
3
  import numpy as np
4
- import os
5
- from fastapi import FastAPI
6
  from pydantic import BaseModel
7
  from collections import defaultdict
8
 
@@ -19,25 +19,30 @@ def root():
19
  class Input(BaseModel):
20
  image_base64: str
21
 
22
- def save_base64_image_cv(base64_str, output_path="final.png"):
23
- img_data = base64.b64decode(base64_str)
24
- nparr = np.frombuffer(img_data, np.uint8)
25
- img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
26
-
27
- if img.shape[2] == 4:
28
- alpha = img[:, :, 3] / 255.0
29
- rgb = img[:, :, :3]
30
- white_bg = np.ones_like(rgb, dtype=np.uint8) * 255
31
- img = (rgb * alpha[:, :, None] + white_bg * (1 - alpha[:, :, None])).astype(np.uint8)
32
 
33
- cv2.imwrite(output_path, img)
 
 
 
 
 
 
 
 
34
 
35
- def extract_icon_positions(image_path):
36
- img = cv2.imread(image_path)
37
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
38
  _, thresh = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY_INV)
39
  contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
40
-
41
  icons, pos = [], []
42
  for c in contours:
43
  x, y, w, h = cv2.boundingRect(c)
@@ -52,16 +57,24 @@ def img_hash(img):
52
  return (img > img.mean()).astype(np.uint8).flatten()
53
 
54
  def find_rarest(icon_features, positions):
 
 
 
55
  hashes = [img_hash(i) for i in icon_features]
56
  groups = defaultdict(list)
57
-
58
  for i, h in enumerate(hashes):
59
- for g in groups.values():
60
- if np.sum(h != hashes[g[0]]) < 3:
61
- g.append(i)
 
 
62
  break
63
- else:
64
  groups[len(groups)] = [i]
 
 
 
65
 
66
  idx = min(groups.values(), key=len)[0]
67
  return positions[idx]
@@ -69,10 +82,19 @@ def find_rarest(icon_features, positions):
69
  @app.post("/solve")
70
  def solve(data: Input):
71
  try:
72
- save_base64_image_cv(data.image_base64)
73
- icons, pos = extract_icon_positions("final.png")
 
 
 
 
 
74
  x, y = find_rarest(icons, pos)
75
- return {"x": x, "y": y}
76
- finally:
77
- if os.path.exists("final.png"):
78
- os.remove("final.png")
 
 
 
 
 
1
  import base64
2
  import cv2
3
  import numpy as np
4
+ import uvicorn
5
+ from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
  from collections import defaultdict
8
 
 
19
  class Input(BaseModel):
20
  image_base64: str
21
 
22
+ def preprocess_image_memory(base64_str):
23
+ try:
24
+ img_data = base64.b64decode(base64_str)
25
+ nparr = np.frombuffer(img_data, np.uint8)
26
+ img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
27
+
28
+ if img is None:
29
+ raise ValueError("Invalid/corrupt image")
 
 
30
 
31
+ if len(img.shape) == 3 and img.shape[2] == 4:
32
+ alpha = img[:, :, 3] / 255.0
33
+ rgb = img[:, :, :3]
34
+ white_bg = np.ones_like(rgb, dtype=np.uint8) * 255
35
+ img = (rgb * alpha[:, :, None] + white_bg * (1 - alpha[:, :, None])).astype(np.uint8)
36
+
37
+ return img
38
+ except Exception as e:
39
+ raise ValueError(f"Error decoding image: {str(e)}")
40
 
41
+ def extract_icon_positions(img):
 
42
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
43
  _, thresh = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY_INV)
44
  contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
45
+
46
  icons, pos = [], []
47
  for c in contours:
48
  x, y, w, h = cv2.boundingRect(c)
 
57
  return (img > img.mean()).astype(np.uint8).flatten()
58
 
59
  def find_rarest(icon_features, positions):
60
+ if not icon_features:
61
+ return None, None
62
+
63
  hashes = [img_hash(i) for i in icon_features]
64
  groups = defaultdict(list)
65
+
66
  for i, h in enumerate(hashes):
67
+ found = False
68
+ for label, group in groups.items():
69
+ if np.sum(h != hashes[group[0]]) < 3:
70
+ group.append(i)
71
+ found = True
72
  break
73
+ if not found:
74
  groups[len(groups)] = [i]
75
+
76
+ if not groups:
77
+ return None, None
78
 
79
  idx = min(groups.values(), key=len)[0]
80
  return positions[idx]
 
82
  @app.post("/solve")
83
  def solve(data: Input):
84
  try:
85
+ img = preprocess_image_memory(data.image_base64)
86
+
87
+ icons, pos = extract_icon_positions(img)
88
+
89
+ if not icons:
90
+ return {"error": "No icons found", "x": 0, "y": 0}
91
+
92
  x, y = find_rarest(icons, pos)
93
+
94
+ return {"x": int(x), "y": int(y)}
95
+
96
+ except Exception as e:
97
+ return {"error": str(e)}
98
+
99
+ if __name__ == "__main__":
100
+ uvicorn.run(app, host="0.0.0.0", port=7860)