Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
# setfit should normally be listed in requirements.txt; installing it at
# runtime is only a fallback for environments where it is missing.
try:
    from setfit import SetFitModel
except ImportError:
    import importlib
    import subprocess
    import sys

    # Run pip through the current interpreter so the package is installed
    # into the environment this script actually uses. check=True makes a
    # failed install raise here instead of surfacing as an ImportError later.
    subprocess.run([sys.executable, "-m", "pip", "install", "setfit"], check=True)
    # Refresh the import system's finder caches so the freshly installed
    # package is discoverable in this already-running interpreter.
    importlib.invalidate_caches()
    from setfit import SetFitModel
# Root of the local Hugging Face cache; $HF_HOME overrides the default.
default_hf_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
# Expand "~" in case $HF_HOME is set to a home-relative path.
HF_HOME = os.path.expanduser(os.environ.get("HF_HOME", default_hf_home))
# Hub downloads live in the "hub" sub-directory of HF_HOME (or $HF_HUB_CACHE
# when set). Passing HF_HOME itself as cache_dir would store snapshots
# outside the standard hub cache, where other tools would not reuse them.
HF_HUB_CACHE = os.path.expanduser(
    os.environ.get("HF_HUB_CACHE", os.path.join(HF_HOME, "hub"))
)

# Hub repository id of the multi-label color-identity classifier.
coloridentity_model = "joshuasundance/mtg-coloridentity-multilabel-classification"
# Names for the classifier's outputs, by position. Presumably this must match
# the label order the model was trained with -- TODO confirm on the model card.
labels = ["black", "green", "red", "blue", "white"]
# Loaded once at import time so every request reuses the same model.
model = SetFitModel.from_pretrained(coloridentity_model, cache_dir=HF_HUB_CACHE)
def get_preds(
    input_text: str, *, threshold: float = 0.5
) -> tuple[str, dict[str, float]]:
    """Classify the color identity of a Magic: The Gathering card.

    Args:
        input_text: Card name and ability text to classify.
        threshold: A color is counted in the identity only when its predicted
            probability is strictly greater than this value. Defaults to 0.5.

    Returns:
        A pair ``(color_identity, pred_dict)``:
        - ``color_identity`` is the "/"-joined names of the colors above
          ``threshold``, in ``labels`` order, or ``"colorless"`` if none are.
        - ``pred_dict`` maps every label to its probability as a plain float.

    Raises:
        ValueError: If the model returns a different number of probabilities
            than there are entries in ``labels``.
    """
    preds = model.predict_proba(input_text)
    # predict_proba presumably yields tensor/ndarray scalars depending on the
    # setfit version; convert to plain floats so the annotation holds and
    # gr.Label receives JSON-serialisable values.
    probs = [float(p) for p in preds]
    # Guard against a label/model mismatch instead of silently mislabeling
    # or failing with a bare IndexError.
    if len(probs) != len(labels):
        raise ValueError(
            f"model returned {len(probs)} probabilities for {len(labels)} labels"
        )
    pred_dict = dict(zip(labels, probs))
    color_identity = "/".join(
        color for color, prob in zip(labels, probs) if prob > threshold
    )
    # An empty join means no color cleared the threshold.
    return color_identity or "colorless", pred_dict
# Web UI: one free-text input and two outputs, matching the
# (color_identity, pred_dict) tuple returned by get_preds.
iface = gr.Interface(
    fn=get_preds,
    inputs=gr.Textbox(),
    outputs=[
        gr.Textbox(),  # "/"-joined color identity, or "colorless"
        gr.Label(),  # per-color probabilities
    ],
    title="Magic the Gathering Color Identity Classifier",
    description="Enter card name and ability text to classify the color identity of the card.",
    # Gradio takes "never" / "auto" / "manual" here; the boolean form is
    # deprecated. "never" is the explicit equivalent of the old False.
    allow_flagging="never",
)
# Start the server, keeping the auto-generated API docs visible.
iface.launch(show_api=True)