Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification | |
# Model identifier passed to `from_pretrained` for both the tokenizer and the
# token-classification model below.
MODEL_ID = "techysanoj/fine-tuned-IndicNER"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)

# Map predicted class indices to tag strings. Keys are coerced to int because
# a config's id2label mapping may carry string keys (e.g. after a JSON
# round-trip), while argmax indices are ints.
id2label = {
    int(class_idx): tag
    for class_idx, tag in model.config.id2label.items()
}
# Display colors for the Gradio HTML output, keyed by BIO tag. The B- (begin)
# and I- (inside) tags of one entity type share a color; "O" (outside any
# entity) is rendered black.
COLOR_MAP = {
    # PER entities
    "B-PER": "red",
    "I-PER": "red",
    # ORG entities
    "B-ORG": "green",
    "I-ORG": "green",
    # LOC entities
    "B-LOC": "blue",
    "I-LOC": "blue",
    # non-entity tokens
    "O": "black",
}
def _render_token(tok, label, conf, color):
    """Return one HTML row: the token, an arrow, its label and its confidence."""
    import html

    # Tokens are derived from user input, so escape them before embedding in
    # markup; a raw "<" or "&" would otherwise be interpreted by gr.HTML.
    safe_tok = html.escape(tok)
    safe_label = html.escape(label)
    return (
        f"<span style='color:{color}; font-weight:bold;'>{safe_tok}</span>"
        f" → <span style='color:{color};'><b>{safe_label}</b></span>"
        f" (conf: {conf:.3f})<br>"
    )


def generate_ner_output(text):
    """Run token-level NER on *text* and render the predictions as HTML.

    Every token produced by the tokenizer (no special tokens are filtered
    out) is shown on its own line, colored by its predicted label, followed
    by that label and the softmax probability the model assigned to it.

    Args:
        text: Raw input from the Gradio textbox. ``None``, empty and
            whitespace-only input are rejected.

    Returns:
        An HTML string for a ``gr.HTML`` component, or the plain message
        "Please enter valid input." when *text* has no content.
    """
    if not text or not text.strip():
        return "Please enter valid input."

    # Truncate to the tokenizer's model_max_length so over-long input does not
    # exceed the model's maximum sequence length and fail in the forward pass
    # (assumes a fixed-length encoder; TODO confirm for this checkpoint).
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the label dimension gives a per-token confidence; the batch
    # holds a single sequence, hence [0].
    probs = F.softmax(logits, dim=-1)[0]
    pred_ids = torch.argmax(probs, dim=-1).tolist()

    rows = []
    for tok, pid, prob_vec in zip(tokens, pred_ids, probs):
        label = id2label[pid]
        conf = float(prob_vec[pid])
        # Labels missing from COLOR_MAP (e.g. extra tags added during
        # fine-tuning) fall back to black instead of raising KeyError.
        color = COLOR_MAP.get(label, "black")
        rows.append(_render_token(tok, label, conf, color))

    return (
        "<div style='font-family: monospace; font-size: 18px;'>"
        + "".join(rows)
        + "</div>"
    )
# ---------- GRADIO UI -------------
# Single-page app: a text box, a button that runs generate_ner_output, and an
# HTML panel that displays the colored per-token predictions.
with gr.Blocks() as demo:
    gr.Markdown("## 🔥 IndicNER — Token-Level NER (Colored + Confidence)")
    text_input = gr.Textbox(label="Enter text", lines=3, placeholder="Type sentence here...")
    run_btn = gr.Button("Generate NER")
    ner_html = gr.HTML(label="NER Output")
    # Clicking the button feeds the textbox contents to the model and writes
    # the returned HTML into the output component.
    run_btn.click(fn=generate_ner_output, inputs=text_input, outputs=ner_html)

demo.launch()