Spaces:
Running
Running
| import logging | |
| import pathlib | |
| import gradio as gr | |
| import pandas as pd | |
| from gt4sd.algorithms.conditional_generation.key_bert import ( | |
| KeywordBERTGenerationAlgorithm, | |
| KeyBERTGenerator, | |
| ) | |
| from gt4sd.algorithms.registry import ApplicationsRegistry | |
# Module-level logger named after this module.
| logger = logging.getLogger(__name__) | |
# Attach a NullHandler, which discards the records it receives; any visible log
# output therefore depends on handlers configured elsewhere.
| logger.addHandler(logging.NullHandler()) | |
def run_inference(
    algorithm_version: str,
    text: str,
    minimum_keyphrase_ngram: int,
    maximum_keyphrase_ngram: int,
    stop_words: str,
    use_maxsum: bool,
    number_of_candidates: int,
    use_mmr: bool,
    diversity: float,
    number_of_keywords: int,
):
    """Extract keywords from ``text`` with a KeywordBERT algorithm.

    The arguments are forwarded unchanged into a ``KeyBERTGenerator``
    configuration, which is then used to build a
    ``KeywordBERTGenerationAlgorithm`` targeting ``text``.

    Args:
        algorithm_version: Version of the algorithm to load.
        text: Input text, passed as the algorithm's ``target``.
        minimum_keyphrase_ngram: Forwarded as ``minimum_keyphrase_ngram``.
        maximum_keyphrase_ngram: Forwarded as ``maximum_keyphrase_ngram``.
        stop_words: Forwarded as ``stop_words``.
        use_maxsum: Forwarded as ``use_maxsum``.
        number_of_candidates: Forwarded as ``number_of_candidates``.
        use_mmr: Forwarded as ``use_mmr``.
        diversity: Forwarded as ``diversity``.
        number_of_keywords: Used both as the configuration's ``top_n`` and as
            the number of items requested from ``sample``.

    Returns:
        A list holding everything yielded by the algorithm's ``sample`` call.
    """
    # NOTE(review): the UI textbox feeding ``stop_words`` has only a
    # placeholder, so an untouched field presumably arrives here as an empty
    # string -- confirm that the generator accepts that value.
    generator_settings = {
        "algorithm_version": algorithm_version,
        "minimum_keyphrase_ngram": minimum_keyphrase_ngram,
        "maximum_keyphrase_ngram": maximum_keyphrase_ngram,
        "stop_words": stop_words,
        "top_n": number_of_keywords,
        "use_maxsum": use_maxsum,
        "use_mmr": use_mmr,
        "diversity": diversity,
        "number_of_candidates": number_of_candidates,
    }
    configuration = KeyBERTGenerator(**generator_settings)
    algorithm = KeywordBERTGenerationAlgorithm(
        configuration=configuration, target=text
    )
    # Materialise the sampler output so the caller receives a plain list.
    return list(algorithm.sample(number_of_keywords))
if __name__ == "__main__":
    # Preparation: collect the versions of every registered algorithm whose
    # name contains "KeywordBERT"; these populate the version dropdown.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        entry["algorithm_version"]
        for entry in all_algos
        if "KeywordBERT" in entry["algorithm_name"]
    ]

    # Load metadata shipped next to this script in the "model_cards" folder.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    # The CSV has no header row; missing cells become empty strings so every
    # example row can be handed to the input components as-is.
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), sep=",", header=None
    ).fillna("")
    # Read the markdown explicitly as UTF-8 so the result does not depend on
    # the locale's default encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    # The order of `inputs` must match the positional parameters of
    # `run_inference`, since Gradio passes the component values positionally.
    demo = gr.Interface(
        fn=run_inference,
        title="KeywordBERT",
        inputs=[
            # NOTE(review): the preselected version is assumed to be among
            # `algos` -- confirm against the registry contents.
            gr.Dropdown(algos, label="Algorithm version", value="circa_bert_v2"),
            gr.Textbox(
                label="Text prompt",
                placeholder="This is a text I want to understand better",
                lines=5,
            ),
            gr.Slider(
                minimum=1, maximum=5, value=1, label="Minimum keyphrase ngram", step=1
            ),
            # The lower bound was 2 while the default was 1, i.e. the default
            # lay outside the slider's own range. The bound is lowered to 1 so
            # the existing default is valid and selectable.
            gr.Slider(
                minimum=1, maximum=10, value=1, label="Maximum keyphrase ngram", step=1
            ),
            gr.Textbox(label="Stop words", placeholder="english", lines=1),
            gr.Radio(choices=[True, False], label="MaxSum", value=False),
            gr.Slider(
                minimum=5, maximum=100, value=20, label="MaxSum candidates", step=1
            ),
            gr.Radio(
                choices=[True, False],
                label="Max. marginal relevance control",
                value=False,
            ),
            gr.Slider(minimum=0.1, maximum=1, value=0.5, label="Diversity"),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of keywords", step=1
            ),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)