MatanKriel committed on
Commit
36e2a11
·
verified ·
1 Parent(s): 1933b1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -43
app.py CHANGED
@@ -7,60 +7,54 @@ from PIL import Image
7
  import io
8
 
9
  # --- CONFIGURATION ---
10
- # Using the lightweight OpenAI CLIP model
11
- MODEL_ID = "openai/clip-vit-base-patch32"
12
  DATA_FILE = "food_embeddings_clip.parquet"
13
 
 
 
 
 
14
  print(f"⏳ Loading {MODEL_ID} and Data...")
15
 
16
- # 1. Load Model (CLIP class is specific and lighter)
17
  model = CLIPModel.from_pretrained(MODEL_ID)
18
  processor = CLIPProcessor.from_pretrained(MODEL_ID)
19
 
20
- # 2. Load the "Memory"
21
  try:
22
  df = pd.read_parquet(DATA_FILE)
 
 
 
23
  except FileNotFoundError:
24
- raise RuntimeError(f"❌ ERROR: Could not find {DATA_FILE}. Did you upload it to Files?")
25
-
26
- # 3. Prepare Database
27
- all_vectors = np.stack(df['embedding'].to_numpy())
28
- db_features = torch.tensor(all_vectors)
29
 
30
  print("βœ… System Ready!")
31
 
 
32
  def search(text_query, image_query):
33
- # A. Decide Input Type
 
 
 
34
  if image_query:
35
  inputs = processor(images=image_query, return_tensors="pt", padding=True)
36
- get_feat_func = model.get_image_features
37
- elif text_query:
38
- inputs = processor(text=[text_query], return_tensors="pt", padding=True)
39
- get_feat_func = model.get_text_features
40
  else:
41
- return None
 
42
 
43
- # B. Run Inference
44
  with torch.no_grad():
45
- query_vec = get_feat_func(**inputs)
 
46
 
47
- # C. Search Logic
48
- # 1. Normalize Query
49
- query_vec = query_vec / query_vec.norm(p=2, dim=-1, keepdim=True)
50
-
51
- # 2. Dot Product
52
- scores = torch.mm(query_vec, db_features.T)
53
-
54
- # 3. Top 5 Results
55
- top_scores, top_indices = torch.topk(scores, k=5)
56
-
57
- # D. Fetch Results
58
  results = []
59
  for idx, score in zip(top_indices[0], top_scores[0]):
60
- idx = idx.item()
61
- row = df.iloc[idx]
62
 
63
- # Load Image from Parquet bytes
64
  img_data = row['image']
65
  if isinstance(img_data, dict) and 'bytes' in img_data:
66
  img = Image.open(io.BytesIO(img_data['bytes']))
@@ -68,23 +62,73 @@ def search(text_query, image_query):
68
  img = img_data
69
 
70
  results.append((img, f"{row['label_name']} ({score.item():.2f})"))
71
-
72
  return results
73
 
74
- # --- INTERFACE ---
75
- with gr.Blocks(title="Food Search (CLIP)") as demo:
76
- gr.Markdown("# πŸ• Lightweight Food Search (CLIP)")
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  with gr.Row():
79
- with gr.Column():
80
- txt_input = gr.Textbox(label="Search by Text", placeholder="e.g. 'sushi on a boat'")
81
- img_input = gr.Image(type="pil", label="Or Search by Image")
82
- btn = gr.Button("Search", variant="primary")
 
 
 
 
 
 
 
 
 
 
83
 
84
- with gr.Column():
85
- gallery = gr.Gallery(label="Top Matches")
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- btn.click(fn=search, inputs=[txt_input, img_input], outputs=gallery)
 
 
88
 
89
- # Bind to 0.0.0.0 for Spaces
90
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
7
  import io
8
 
9
  # --- CONFIGURATION ---
10
+ MODEL_ID = "openai/clip-vit-base-patch32"
 
11
  DATA_FILE = "food_embeddings_clip.parquet"
12
 
13
+ # πŸŽ₯ PASTE YOUR YOUTUBE VIDEO ID HERE
14
+ # (e.g. if link is https://www.youtube.com/watch?v=dQw4w9WgXcQ, the ID is dQw4w9WgXcQ)
15
+ YOUTUBE_ID = "IXeIxYHi0Es"
16
+
17
  print(f"⏳ Loading {MODEL_ID} and Data...")
18
 
19
+ # 1. Load Model
20
  model = CLIPModel.from_pretrained(MODEL_ID)
21
  processor = CLIPProcessor.from_pretrained(MODEL_ID)
22
 
23
+ # 2. Load Data
24
  try:
25
  df = pd.read_parquet(DATA_FILE)
26
+ # Prepare Vectors
27
+ all_vectors = np.stack(df['embedding'].to_numpy())
28
+ db_features = torch.tensor(all_vectors)
29
  except FileNotFoundError:
30
+ raise RuntimeError(f"❌ ERROR: Could not find {DATA_FILE}. Did you upload it?")
 
 
 
 
31
 
32
  print("βœ… System Ready!")
33
 
34
66
 
67
+ # --- APP INTERFACE (The Original Design) ---
68
+ # We use a 'Soft' theme for a professional look
69
+ with gr.Blocks(theme=gr.themes.Soft(), title="AI Food Search") as demo:
70
 
71
+ # 1. Header Section
72
+ gr.Markdown(
73
+ """
74
+ # πŸ” AI Food Search Engine
75
+ ### Powered by OpenAI CLIP & Hugging Face
76
+ Search through 5,000 food images using natural language or reference images.
77
+ """
78
+ )
79
+
80
+ # 2. YouTube Demo Section (Embedded Player)
81
+ if YOUTUBE_ID and YOUTUBE_ID != "YOUR_YOUTUBE_ID_HERE":
82
+ gr.HTML(
83
+ f"""
84
+ <div style="display: flex; justify-content: center; margin-bottom: 20px;">
85
+ <iframe width="560" height="315"
86
+ src="https://www.youtube.com/embed/{YOUTUBE_ID}"
87
+ title="YouTube video player" frameborder="0"
88
+ allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
89
+ allowfullscreen></iframe>
90
+ </div>
91
+ """
92
+ )
93
+ else:
94
+ gr.Info("ℹ️ Add your YouTube ID in the code to display the video here.")
95
+
96
+ # 3. Main Search Interface
97
  with gr.Row():
98
+ # Left Column: Inputs
99
+ with gr.Column(scale=1):
100
+ gr.Markdown("### πŸ” Your Query")
101
+ txt_input = gr.Textbox(
102
+ label="Search by Text",
103
+ placeholder="e.g. 'spicy tacos with lime'",
104
+ show_label=True
105
+ )
106
+ gr.Markdown("**OR**")
107
+ img_input = gr.Image(
108
+ type="pil",
109
+ label="Search by Image",
110
+ height=300
111
+ )
112
 
113
+ search_btn = gr.Button("πŸš€ Find Food", variant="primary", size="lg")
114
+
115
+ # Right Column: Results
116
+ with gr.Column(scale=2):
117
+ gr.Markdown("### πŸ• Top Matches")
118
+ gallery = gr.Gallery(
119
+ label="Results",
120
+ columns=3,
121
+ height="auto",
122
+ object_fit="cover"
123
+ )
124
+
125
+ # 4. Footer / Credits
126
+ gr.Markdown("---")
127
+ gr.Markdown(f"*Model: {MODEL_ID} | Dataset: Food-101 (Subset)*")
128
 
129
+ # Event Listeners (Enter Key + Button Click)
130
+ txt_input.submit(search, inputs=[txt_input, img_input], outputs=gallery)
131
+ search_btn.click(search, inputs=[txt_input, img_input], outputs=gallery)
132
 
133
+ # Launch
134
  demo.launch(server_name="0.0.0.0", server_port=7860)