leadr64 commited on
Commit
3a7a1b1
·
1 Parent(s): 8e7416e

Ajouter le script Gradio et les dépendances

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -1,47 +1,50 @@
1
  import os
2
-
3
  import gradio as gr
4
  from qdrant_client import QdrantClient
5
  from transformers import ClapModel, ClapProcessor
 
 
6
  QDRANT_URL = os.getenv("QDRANT_URL")
7
  QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
 
 
 
 
8
  # Loading the Qdrant DB in local ###################################################################
9
- client = QdrantClient(QDRANT_URL, QDRANT_API_KEY)
10
  print("[INFO] Client created...")
11
 
12
  # loading the model
13
  print("[INFO] Loading the model...")
14
- model_name = "laion/larger_clap_general"
15
  model = ClapModel.from_pretrained(model_name)
16
  processor = ClapProcessor.from_pretrained(model_name)
17
 
18
  # Gradio Interface #################################################################################
19
  max_results = 10
20
 
21
-
22
  def sound_search(query):
23
  text_inputs = processor(text=query, return_tensors="pt")
24
  text_embed = model.get_text_features(**text_inputs)[0]
25
 
26
  hits = client.search(
27
  collection_name="demo_spaces_db",
28
- query_vector=text_embed,
29
  limit=max_results,
30
  )
31
  return [
32
  gr.Audio(
33
- hit.payload['audio_path'],
34
  label=f"style: {hit.payload['style']} -- score: {hit.score}")
35
  for hit in hits
36
  ]
37
 
38
-
39
  with gr.Blocks() as demo:
40
  gr.Markdown(
41
  """# Sound search database """
42
  )
43
  inp = gr.Textbox(placeholder="What sound are you looking for ?")
44
  out = [gr.Audio(label=f"{x}") for x in range(max_results)] # Necessary to have different objs
45
- inp.change(sound_search, inp, out)
46
 
47
  demo.launch()
 
1
import os

import gradio as gr
from qdrant_client import QdrantClient
from transformers import ClapModel, ClapProcessor

# Qdrant connection settings come from the environment so the same script
# runs locally and on a hosted Space without code changes.
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

# Fail fast with a clear message instead of an opaque connection error later.
if not QDRANT_URL or not QDRANT_API_KEY:
    raise ValueError("Please set the QDRANT_URL and QDRANT_API_KEY environment variables")

# Loading the Qdrant DB in local ###################################################################
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
print("[INFO] Client created...")

# loading the model
print("[INFO] Loading the model...")
# FIX: "laion/clap-large-v2" is not a published Hub repo id; the checkpoint this
# app was originally built against is "laion/larger_clap_general", so
# from_pretrained() would fail at startup with the previous name.
model_name = "laion/larger_clap_general"
model = ClapModel.from_pretrained(model_name)
processor = ClapProcessor.from_pretrained(model_name)

# Gradio Interface #################################################################################
# Fixed number of audio result slots rendered in the UI (and search limit).
max_results = 10
25
 
 
26
def sound_search(query):
    """Embed *query* with CLAP and return matching clips from Qdrant as audio widgets.

    Relies on the module-level ``processor``, ``model``, ``client`` and
    ``max_results`` objects defined above.
    """
    # Tokenize the text and run it through the CLAP text tower; take the
    # first (only) embedding and convert it to a plain list of floats,
    # which is what the Qdrant client expects as a query vector.
    tokens = processor(text=query, return_tensors="pt")
    embedding = model.get_text_features(**tokens)[0].tolist()

    results = client.search(
        collection_name="demo_spaces_db",
        query_vector=embedding,
        limit=max_results,
    )

    # One gr.Audio per hit, labelled with its style and similarity score.
    audios = []
    for hit in results:
        caption = f"style: {hit.payload['style']} -- score: {hit.score}"
        audios.append(gr.Audio(value=hit.payload['audio_path'], label=caption))
    return audios
41
 
 
42
# Build the Gradio UI: one text box driving a fixed grid of audio outputs.
with gr.Blocks() as demo:
    gr.Markdown(
        """# Sound search database """
    )
    inp = gr.Textbox(placeholder="What sound are you looking for ?")
    # One distinct gr.Audio component per result slot; Gradio requires
    # separate component objects to target separate outputs.
    out = [gr.Audio(label=f"{x}") for x in range(max_results)]  # Necessary to have different objs
    # Re-run the search whenever the text box value changes; sound_search
    # returns a list of gr.Audio updates, one per slot in `out`.
    inp.change(fn=sound_search, inputs=inp, outputs=out)

demo.launch()