Corin1998 commited on
Commit
8cef4d5
·
verified ·
1 Parent(s): f8c3406

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -60
app.py CHANGED
@@ -10,8 +10,7 @@ DB_PATH = "my_db"
10
  if not os.path.exists(DB_PATH):
11
  os.makedirs(DB_PATH, exist_ok=True)
12
 
13
- # 認識の負荷を下げるためのカウンター(3フレームに1回だけ識別する等)
14
- # ※Hugging Faceの負荷対策
15
  frame_count = 0
16
 
17
  # --- 関数定義 ---
@@ -19,90 +18,86 @@ frame_count = 0
19
  def register_face(image, name):
20
  """英語名のみに制限して保存する安全版"""
21
  if image is None or name.strip() == "":
22
- return "名前(英数字)を入力してください。"
23
-
24
  try:
25
- # 名前から記号や全角を排除(簡易版)
26
- safe_name = "".join([c for c in name if c.islnum()])
27
  file_path = os.path.join(DB_PATH, f"{safe_name}.jpg")
28
-
29
  image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
30
  cv2.imwrite(file_path, image_bgr)
31
-
32
- # 登録後にDeepFaceのキャッシュをクリア(新しく登録した人を即反映させるため)
33
- if os.path.exsts(os.path.join(DB_PATH, "ds_model_vgg_face.pkl")):
34
- os.remove(os.path.join(DB_PATH,"ds_model_vgg_face.pkl"))
35
-
36
- return f"「{safe_name}」さんを登録しました。②タブでカメラを起動してください。"
37
  except Exception as e:
38
  return f"エラー: {str(e)}"
39
 
40
  def track_oshi(frame):
41
- """カメラ映像(1フレーム)を受け取って推しを判定する"""
42
  if frame is None:
43
  return None
44
-
45
  global frame_count
46
  frame_count += 1
47
-
48
- # 毎フレーム解析すると重いため、2フレームに1回解析する
49
- if frame_coutn % 2 ! = 0:
 
50
  return frame
51
 
52
  try:
53
- # DeepFace.find でDB内の推しを検索
54
- # model_name="OpenFace"は比較的軽量で高速
55
  results = DeepFace.find(
56
- img_paht=frame,
57
- db_path=DB_PATH,
58
- enforce_detection=False,
59
- detector_backend='opencv', # 高速な検出機を選択
60
- silent =True,
61
  )
62
-
63
  output_frame = frame.copy()
64
-
65
  for df in results:
66
  if not df.empty:
67
- # 登録名を取得
68
- name = os.pathbasename(row['identity']).split('.')[0]
69
-
70
- # 座標名を取得
71
- x, y, w, h = int(row['source_x']), int(row['source_y']), int(row['source_w']), int(row['source_h'])
72
-
73
- #枠と名前を描写(推しを「ロックオン」している演出)
74
- cv2.rectangle(output_frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
75
- cv2.rectangle(output_frame, (x, y - 30), (x + w, y), (0, 255, 0), -1)
76
- cv2.putText(output_frame, f"TARGET: {name}", (x + 5, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
77
  return output_frame
78
 
79
  except Exception as e:
80
- print(f"Tracking error: {e}")
81
  return frame
82
 
83
- # --- UI 構築 ---
84
 
85
  with gr.Blocks() as demo:
86
- gr.Markdown("# 🎥 リアルタイム推し追跡プロトタイプ")
87
-
88
  with gr.Tabs():
89
- with gr.Row():
90
- reg_in = gr.Tmage(label="推しの写真をアップロード")
91
- reg_name = gr.Textbox(label= "推しの名前(半角英数字)")
92
- reg_but = gr.Button("データベースに登録")
93
- reg_status = gr.Textbox(label="状況")
94
- reg_btn.click(register_face, inputs=[reg_in, reg_name], outputs=reg_status)
95
-
96
- with gr.TabItem("② リアルタイム追跡"):
97
- gr.Markdown("カメラを許可すると、登録した推しを自動で探し続けます。")
98
- # streaming=True にすることで、連続的に関数が呼ばれる
99
- input_cideo = gr.Image(source=["webcam"], streaming=True, label="Webカメラ映像")
100
- # 出力もImageで行う
101
- output_video = gr.Image(label= "推し追跡(ロックオン状態)")
102
-
103
- # input_videoの内容が更新されるたびにtrack_oshiを実行
104
- input_video.stream(track_oshi, inputs=[input_video], outputs=[output_video], time_limit=30)
105
-
106
- # 起動
107
- of __name__ == "__main__":
108
  demo.launch()
 
10
  if not os.path.exists(DB_PATH):
11
  os.makedirs(DB_PATH, exist_ok=True)
12
 
13
+ # 処理フレームの間隔(例:2なら2フレームに1回解析)
 
14
  frame_count = 0
15
 
16
  # --- 関数定義 ---
 
18
  def register_face(image, name):
19
  """英語名のみに制限して保存する安全版"""
20
  if image is None or name.strip() == "":
21
+ return "名前(英数字)を入力してください。"
22
+
23
  try:
24
+ # 記号やスペースを除去した安全な名前
25
+ safe_name = "".join([c for c in name if c.isalnum()])
26
  file_path = os.path.join(DB_PATH, f"{safe_name}.jpg")
27
+
28
  image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
29
  cv2.imwrite(file_path, image_bgr)
30
+
31
+ # 登録後にキャッシュを削除
32
+ if os.path.exists(os.path.join(DB_PATH, "ds_model_vgg_face.pkl")):
33
+ os.remove(os.path.join(DB_PATH, "ds_model_vgg_face.pkl"))
34
+
35
+ return f"「{safe_name}」を登録しました。②タブへ進んでください。"
36
  except Exception as e:
37
  return f"エラー: {str(e)}"
38
 
39
  def track_oshi(frame):
40
+ """カメラ映像から推しを判定する"""
41
  if frame is None:
42
  return None
43
+
44
  global frame_count
45
  frame_count += 1
46
+
47
+ # 【修正箇所】演算子と変数名を修正
48
+ # 2フレームに1回だけ重い処理を行うことで動作を軽くする
49
+ if frame_count % 2 != 0:
50
  return frame
51
 
52
  try:
53
+ # 顔の検索(検出器に高速なopencvを指定)
 
54
  results = DeepFace.find(
55
+ img_path=frame,
56
+ db_path=DB_PATH,
57
+ enforce_detection=False,
58
+ detector_backend='opencv',
59
+ silent=True
60
  )
61
+
62
  output_frame = frame.copy()
63
+
64
  for df in results:
65
  if not df.empty:
66
+ for _, row in df.iterrows():
67
+ name = os.path.basename(row['identity']).split('.')[0]
68
+ x, y, w, h = int(row['source_x']), int(row['source_y']), int(row['source_w']), int(row['source_h'])
69
+
70
+ # 枠と名前を描画
71
+ cv2.rectangle(output_frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
72
+ cv2.putText(output_frame, f"TARGET: {name}", (x, y-10),
73
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
74
+
 
75
  return output_frame
76
 
77
  except Exception as e:
 
78
  return frame
79
 
80
+ # --- UI (Gradio) ---
81
 
82
  with gr.Blocks() as demo:
83
+ gr.Markdown("# 🎥 リアルタイム推し追跡")
84
+
85
  with gr.Tabs():
86
+ with gr.TabItem("① 推し登録"):
87
+ with gr.Row():
88
+ reg_in = gr.Image(label="推しの写真")
89
+ reg_name = gr.Textbox(label="名前(半角英数字)", value="oshi")
90
+ reg_btn = gr.Button("登録")
91
+ reg_status = gr.Textbox(label="ステータス")
92
+ reg_btn.click(register_face, inputs=[reg_in, reg_name], outputs=reg_status)
93
+
94
+ with gr.TabItem("② リアルタイム追跡"):
95
+ gr.Markdown("カメラを起動して自分や推しを映してください。")
96
+ input_video = gr.Image(sources=["webcam"], streaming=True)
97
+ output_video = gr.Image(label="解析結果")
98
+
99
+ # ストリーミング設定
100
+ input_video.stream(track_oshi, inputs=[input_video], outputs=[output_video])
101
+
102
+ if __name__ == "__main__":
 
 
103
  demo.launch()