LTEnjoy committed on
Commit
f0b07c6
·
verified ·
1 Parent(s): 3f7cec9

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +83 -0
  2. loop_retrieve_cards.py +56 -0
  3. utils.py +89 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
# Take everything before the last two "/"-separated components of this file's
# path (for ".../a/b/app.py" that is ".../a") and make it importable.
# NOTE(review): if __file__ contains fewer than two "/" the result is not a
# directory (e.g. a bare "app.py" yields "app.py") - confirm how this is launched.
root_dir = __file__.rsplit("/", 2)[0]
if root_dir not in sys.path:
    # Appended (not inserted), so existing sys.path entries keep precedence.
    sys.path.append(root_dir)
6
+
7
+ import gradio as gr
8
+
9
+ from utils import set_text_bg_color
10
+ from loop_retrieve_cards import get_models, get_datasets, get_readme_dict
11
+
12
+
13
def match_card(input: str, card_id: str, card_type: str) -> str:
    """
    Search the input in a card and render the matching parts as Markdown.

    The card is displayed when the input string occurs (case-insensitively) in the
    card id or in at least one README line that does not contain "<!--".

    Args:
        input: Input string
        card_id: HuggingFace card id
        card_type: Type of card, either "model" or "dataset"

    Returns:
        Markdown consisting of a linked heading, the highlighted matching README
        lines and a trailing separator, or an empty string if the card does not match.
    """
    keyword = input.lower()

    # The README cache is filled by a background thread and may not contain this
    # card (yet); treat a missing README as empty instead of raising KeyError.
    readme = get_readme_dict().get(card_id, "")

    # Highlight lines that contain the input string, ignoring comment lines
    show_lines = [
        set_text_bg_color(input, line)
        for line in readme.split("\n")
        if keyword in line.lower() and "<!--" not in line
    ]

    # Nothing to show: the keyword is neither in the id nor in a visible README line
    # (this also covers the case that it is only contained in comments)
    if keyword not in card_id.lower() and not show_lines:
        return ""

    # Add card id
    if card_type == "model":
        link = f"https://huggingface.co/{card_id}"
    else:
        link = f"https://huggingface.co/datasets/{card_id}"
    heading = f"## [{set_text_bg_color(input, card_id)}]({link})\n\n"

    # Add README lines and a separator
    readme_part = "\n\n".join(show_lines)
    return f"\n\n{heading}{readme_part}\n\n---\n\n"
49
+
50
+
51
def show_card_info(input: str):
    """
    Build the search result view for the given keyword.

    Args:
        input: Keyword typed into the search box

    Returns:
        A visible ``gr.Markdown`` holding a "Models" and a "Datasets" section with
        every matching card, or an empty one when the keyword is empty.
    """
    if input == "":
        return gr.Markdown("", visible=True)

    # Both getters return None until the background thread has completed its
    # first fetch; fall back to an empty list so early searches do not crash.
    models = get_models() or []
    datasets = get_datasets() or []

    # Search models
    parts = ["# Models\n\n"]
    parts.extend(match_card(input, model, "model") for model in models)

    # Search datasets
    parts.append("# Datasets\n\n")
    parts.extend(match_card(input, dataset, "dataset") for dataset in datasets)

    return gr.Markdown("".join(parts), visible=True)
66
+
67
+
68
+ # Build demo
69
# Components are created in display order inside the Blocks context.
with gr.Blocks(title="SaprotHub", fill_width=True) as demo:
    gr.Label("SaprotHub search", visible=True, show_label=False)
    search_box = gr.Textbox(label="Search box", placeholder="Input keywords to search", interactive=True, scale=0, container=True)

    # Display search results
    # NOTE: search_hint is not referenced again in the visible code.
    search_hint = gr.Markdown("# Search results:", visible=True)
    # Created hidden; show_card_info returns a Markdown with visible=True for it.
    items = gr.Markdown(visible=False)

    # Set events: re-run the search whenever the text box content changes
    search_box.change(show_card_info, inputs=[search_box], outputs=[items])


if __name__ == '__main__':
    # Run the demo
    # NOTE(review): server_name="0.0.0.0" presumably exposes the app on all
    # network interfaces - confirm this is intended for the deployment.
    demo.launch(server_name="0.0.0.0")
loop_retrieve_cards.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ import time
3
+
4
+ from utils import fetch_models, fetch_datasets, fetch_readme
5
+ from tqdm import tqdm
6
+
7
+
8
+ # Define global variables
9
# Module-level cache; the accessors below hand out whatever is currently stored.
models = None
datasets = None
readme_dict = {}


def get_models():
    """Provide an API to get models: return the current module-level ``models``."""
    return models


def get_datasets():
    """Provide an API to get datasets: return the current module-level ``datasets``."""
    return datasets


def get_readme_dict():
    """Provide an API to get READMEs: return the current module-level ``readme_dict``."""
    return readme_dict
27
+
28
+
29
+ # Start a thread to continuously update cards
30
+ def run():
31
+ global models, datasets, readme_dict, cnt
32
+
33
+ while True:
34
+ try:
35
+ new_models = fetch_models()
36
+ new_datasets = fetch_datasets()
37
+
38
+ # Add READMEs
39
+ new_readme_dict = {}
40
+ for model in new_models:
41
+ new_readme_dict[model] = fetch_readme(model, "model")
42
+
43
+ for dataset in new_datasets:
44
+ new_readme_dict[dataset] = fetch_readme(dataset, "dataset")
45
+
46
+ # Update global variables
47
+ models = new_models
48
+ datasets = new_datasets
49
+ readme_dict = new_readme_dict
50
+
51
+ except Exception as e:
52
+ print(e)
53
+
54
+
55
# run() loops forever; a daemon thread does not keep the interpreter alive once
# the main thread (the web app) exits.
t = threading.Thread(target=run, daemon=True)
t.start()
utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import re
3
+
4
+
5
def fetch_models(author: str = "SaProtHub", timeout: float = 30.0) -> list:
    """
    Retrieve models belonging to a specific author

    Args:
        author: Author name
        timeout: Seconds to wait for the server before giving up

    Returns:
        models: List of model ids

    Raises:
        requests.RequestException: If the request fails, times out or the
            server answers with an error status.
    """
    url = "https://hf-mirror.com/api/models"
    # Pass the author as a query parameter so that it is URL-encoded, and never
    # wait indefinitely on a stalled connection.
    response = requests.get(url, params={"author": author}, timeout=timeout)
    # Fail with a clear HTTP error instead of trying to parse an error payload
    response.raise_for_status()

    return [item["id"] for item in response.json()]
22
+
23
+
24
def fetch_datasets(author: str = "SaProtHub", timeout: float = 30.0) -> list:
    """
    Retrieve datasets belonging to a specific author

    Args:
        author: Author name
        timeout: Seconds to wait for the server before giving up

    Returns:
        datasets: List of dataset ids

    Raises:
        requests.RequestException: If the request fails, times out or the
            server answers with an error status.
    """
    url = "https://hf-mirror.com/api/datasets"
    # Pass the author as a query parameter so that it is URL-encoded, and never
    # wait indefinitely on a stalled connection.
    response = requests.get(url, params={"author": author}, timeout=timeout)
    # Fail with a clear HTTP error instead of trying to parse an error payload
    response.raise_for_status()

    return [item["id"] for item in response.json()]
41
+
42
+
43
def fetch_readme(card_id: str, card_type: str, timeout: float = 30.0) -> str:
    """
    Retrieve the README file of a model or dataset

    Args:
        card_id: Model or dataset ID
        card_type: Type of card, either "model" or "dataset"
        timeout: Seconds to wait for the server before giving up

    Returns:
        readme: README text without its leading "---" delimited metadata block,
            or an empty string if the README could not be downloaded
            (non-success HTTP status).

    Raises:
        requests.RequestException: If the request itself fails or times out.
    """
    if card_type == "model":
        url = f"https://hf-mirror.com/{card_id}/raw/main/README.md"
    else:
        url = f"https://hf-mirror.com/datasets/{card_id}/raw/main/README.md"

    response = requests.get(url, timeout=timeout)
    if not response.ok:
        # E.g. the card has no README: do not index the error page as content
        return ""

    # Remove only a metadata block at the very start of the file. Splitting on
    # every "---" would also cut at horizontal rules and table separators and
    # drop everything in front of the last one.
    front_matter = r"\A---[ \t]*\r?\n(?:.*?\r?\n)?---[ \t]*(?:\r?\n|\Z)"
    readme = re.sub(front_matter, "", response.text, count=1, flags=re.DOTALL)

    return readme
63
+
64
+
65
def set_text_bg_color(pattern: str, text: str, color: str = "yellow") -> str:
    """
    Set the background color of a pattern in a text

    Every case-insensitive occurrence of ``pattern`` is wrapped in a
    ``<span style="background-color:...">`` element; the matched text keeps
    its original casing. ``pattern`` is treated as a literal string, not as
    a regular expression.

    Args:
        pattern: Pattern to highlight
        text: Text to search
        color: Background color

    Returns:
        text: Text with highlighted pattern (unchanged if the pattern is
            empty or does not occur)
    """
    if not pattern:
        # An empty pattern matches between all characters and would wrap
        # empty spans around every one of them.
        return text

    def _highlight(match):
        return f'<span style="background-color:{color}">{match.group(0)}</span>'

    return re.sub(re.escape(pattern), _highlight, text, flags=re.IGNORECASE)
89
+