Go-Raw commited on
Commit
bf2ed0a
·
verified ·
1 Parent(s): 721f0f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -14
app.py CHANGED
@@ -7,30 +7,26 @@ import requests
7
  from io import BytesIO
8
  import gradio as gr
9
 
10
- # Silence pandas warning
11
  pd.options.mode.chained_assignment = None
12
 
13
- # Load pre-computed meme embeddings
14
  embeddings = pickle.load(open(
15
  hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="meme-embeddings.pkl"), "rb"))
16
 
17
- # Load meme metadata (must include 'id' and 'url' columns)
18
  df = pd.read_csv(hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="input.csv"))
19
 
20
- # Load sentence transformer model
21
  model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
22
 
23
- # Meme search logic
24
  def generate_memes(prompt):
25
  prompt_embedding = model.encode(prompt, convert_to_tensor=True)
26
  hits = util.semantic_search(prompt_embedding, embeddings, top_k=6)
27
  hits_df = pd.DataFrame(hits[0], columns=["corpus_id", "score"])
28
-
29
- # Filter top matching memes
30
  matched_ids = hits_df["corpus_id"]
31
  matched_memes = df[df["id"].isin(matched_ids)]
32
-
33
- # Fetch and return meme images
34
  images = []
35
  for url in matched_memes["url"]:
36
  try:
@@ -41,16 +37,14 @@ def generate_memes(prompt):
41
  print(f"Error loading image {url}: {e}")
42
  return images
43
 
44
- # Gradio UI components
45
  input_textbox = gr.Textbox(lines=1, label="Search something cool")
46
  output_gallery = gr.Gallery(label="Retrieved Memes", columns=3, rows=2, height="auto")
47
 
48
- # App metadata
49
  title = "Semantic Search for Memes"
50
  description = "Search memes from a dataset of ~6k memes using semantic similarity. [GitHub Repo](https://github.com/bhavya-giri/retrieving-memes)"
51
  examples = ["Get Shreked", "Going Crazy", "Spiderman is my teacher"]
52
 
53
- # Interface setup
54
  iface = gr.Interface(
55
  fn=generate_memes,
56
  inputs=input_textbox,
@@ -59,9 +53,7 @@ iface = gr.Interface(
59
  cache_examples=True,
60
  title=title,
61
  description=description,
62
- interpretation='default',
63
  enable_queue=True
64
  )
65
 
66
- # Run the app
67
  iface.launch()
 
7
  from io import BytesIO
8
  import gradio as gr
9
 
 
10
  pd.options.mode.chained_assignment = None
11
 
12
+ # Load precomputed embeddings
13
  embeddings = pickle.load(open(
14
  hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="meme-embeddings.pkl"), "rb"))
15
 
16
+ # Load meme metadata
17
  df = pd.read_csv(hf_hub_download("bhavyagiri/semantic-memes", repo_type="dataset", filename="input.csv"))
18
 
19
+ # Load model
20
  model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
21
 
22
+ # Meme search function
23
  def generate_memes(prompt):
24
  prompt_embedding = model.encode(prompt, convert_to_tensor=True)
25
  hits = util.semantic_search(prompt_embedding, embeddings, top_k=6)
26
  hits_df = pd.DataFrame(hits[0], columns=["corpus_id", "score"])
 
 
27
  matched_ids = hits_df["corpus_id"]
28
  matched_memes = df[df["id"].isin(matched_ids)]
29
+
 
30
  images = []
31
  for url in matched_memes["url"]:
32
  try:
 
37
  print(f"Error loading image {url}: {e}")
38
  return images
39
 
40
+ # UI
41
  input_textbox = gr.Textbox(lines=1, label="Search something cool")
42
  output_gallery = gr.Gallery(label="Retrieved Memes", columns=3, rows=2, height="auto")
43
 
 
44
  title = "Semantic Search for Memes"
45
  description = "Search memes from a dataset of ~6k memes using semantic similarity. [GitHub Repo](https://github.com/bhavya-giri/retrieving-memes)"
46
  examples = ["Get Shreked", "Going Crazy", "Spiderman is my teacher"]
47
 
 
48
  iface = gr.Interface(
49
  fn=generate_memes,
50
  inputs=input_textbox,
 
53
  cache_examples=True,
54
  title=title,
55
  description=description,
 
56
  enable_queue=True
57
  )
58
 
 
59
  iface.launch()