Spaces:
Runtime error
BioMike BioMikeUkr
Replaced the model with knowledgator/gliner-multitask-v1.0 and the NER task with text classification.
b03b2db | from gliner import GLiNER | |
| import gradio as gr | |
# Pretrained checkpoint identifier handed to GLiNER.from_pretrained.
_MODEL_ID = "knowledgator/gliner-multitask-v1.0"

# Loaded once, at import time, and moved to the CPU.
model = GLiNER.from_pretrained(_MODEL_ID).to("cpu")
# Instruction prefix for every classification request; the single ``{}`` slot
# receives the comma-joined candidate labels.
PROMPT_TEMPLATE = "Classify the given text having the following classes: {}"

# Threshold shared by all demo rows below.
_EXAMPLE_THRESHOLD = 0.5

# Demo rows for gr.Examples, each shaped [text, comma-separated labels, threshold]
# to line up with the UI inputs.
classification_examples = [
    [sample_text, sample_labels, _EXAMPLE_THRESHOLD]
    for sample_text, sample_labels in (
        (
            "The sun is shining and the weather is warm today.",
            "Weather, Food, Technology",
        ),
        (
            "I really enjoyed the pizza we had for dinner last night.",
            "Food, Weather, Sports",
        ),
        (
            "Das Kind spielt im Park und genießt die frische Luft.",
            "Nature, Technology, Politics",
        ),
    )
]
def prepare_prompts(text, labels):
    """Build the instruction prompt sent to the model.

    Args:
        text: The text to classify.
        labels: Iterable of label strings; they are joined with ``", "``.

    Returns:
        ``PROMPT_TEMPLATE`` filled with the joined labels, followed by a
        newline and then *text*.
    """
    header = PROMPT_TEMPLATE.format(", ".join(labels))
    return header + "\n" + text
def process(text, labels, threshold):
    """Classify *text* against comma-separated *labels* and return highlights.

    The labels are embedded in an instruction prompt via ``prepare_prompts``.
    The model is then asked, via ``model.run``, for spans of the single type
    ``"match"``. Each returned span becomes a highlighted entity.

    Args:
        text: Text to classify. ``None`` is treated as an empty string.
        labels: Comma-separated candidate labels, e.g. ``"Food, Weather"``.
            Blank entries (``"a, , b"`` or a trailing comma) are ignored.
            ``None`` is treated as an empty string.
        threshold: Minimum score for a span to be kept; forwarded to
            ``model.run``.

    Returns:
        A dict ``{"text": ..., "entities": [...]}`` as consumed by
        ``gr.HighlightedText``. If the text is blank or no non-blank label
        remains, ``"text"`` is the input text and ``"entities"`` is empty.
        Otherwise ``"text"`` is the full prompt, because the span offsets
        produced by the model index into the prompt, not into *text*.
    """
    text = text or ""
    # Drop blank entries so stray commas do not put empty classes into the prompt.
    label_list = [label.strip() for label in (labels or "").split(",") if label.strip()]
    if not text.strip() or not label_list:
        return {"text": text, "entities": []}

    prompt = prepare_prompts(text, label_list)
    predictions = model.run([prompt], ["match"], threshold=threshold)
    # Only one prompt is sent, so only the first result list is relevant
    # (presumably one list of span dicts per input text).
    spans = predictions[0] if predictions and predictions[0] else []
    entities = [
        {
            "entity": "match",
            "word": pred["text"],
            "start": pred["start"],
            "end": pred["end"],
            "score": pred["score"],
        }
        for pred in spans
    ]
    return {"text": prompt, "entities": entities}
# Visual theme for the demo (previously constructed but never applied).
theme = gr.themes.Base()

# Gradio UI: input text + comma-separated labels + threshold in,
# highlighted "match" spans over the prompt out.
with gr.Blocks(
    title="Text Classification with Highlighted Labels", theme=theme
) as classification_interface:
    gr.Markdown("# Text Classification with Highlighted Labels")
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text for classification")
    input_labels = gr.Textbox(
        label="Labels (Comma-Separated)",
        placeholder="Enter labels separated by commas (e.g., Positive, Negative, Neutral)",
    )
    threshold = gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold")
    output = gr.HighlightedText(label="Classification Results")
    submit_btn = gr.Button("Classify")

    # Every trigger passes the same components, in process()'s argument order.
    process_inputs = [input_text, input_labels, threshold]

    examples = gr.Examples(
        examples=classification_examples,
        inputs=process_inputs,
        outputs=output,
        fn=process,
        # Example outputs are precomputed with process(), i.e. model inference.
        cache_examples=True,
    )

    # Re-run classification on Enter in the text box, on slider release, and on
    # button click; all listeners render into the same `output` component.
    input_text.submit(fn=process, inputs=process_inputs, outputs=output)
    threshold.release(fn=process, inputs=process_inputs, outputs=output)
    submit_btn.click(fn=process, inputs=process_inputs, outputs=output)

if __name__ == "__main__":
    classification_interface.launch()