Spaces:
Sleeping
Sleeping
| import json, time, csv, os | |
| import gradio as gr | |
| from transformers import pipeline | |
# ────────────────
# Load taxonomies
# ────────────────
# JSON text is UTF-8; naming the encoding keeps non-ASCII labels loading the
# same way regardless of the platform's locale default.
with open("coarse_labels.json", encoding="utf-8") as f:
    # Passed unchanged as `candidate_labels` in stage 1.
    coarse_labels = json.load(f)
with open("fine_labels.json", encoding="utf-8") as f:
    # Queried in stage 2 with `.get(subject, [])`.
    fine_map = json.load(f)
# ────────────────
# Model choices
# ────────────────
# Model ids offered in the "Choose model" dropdown; the first entry is used
# as the dropdown's initial value.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "valhalla/distilbart-mnli-12-4",
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
    # Placeholder — replace with the real "phantom" model.
    "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
]
# Cache of already-built classifiers, keyed by model name.
PIPELINES = {}


def get_pipeline(name):
    """Return the zero-shot classification pipeline for *name*.

    The pipeline is built on the first request for a given name and stored
    in ``PIPELINES``; later requests return the stored object.
    """
    if name in PIPELINES:
        return PIPELINES[name]

    classifier = pipeline("zero-shot-classification", model=name)
    PIPELINES[name] = classifier
    return classifier
# ────────────────
# Ensure log files exist
# ────────────────
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

# Create each CSV with its header row if it is missing; a file that already
# exists is left untouched.
for csv_path, header_row in (
    (LOG_FILE, ["timestamp", "model", "question", "chosen_subject", "top3_topics", "duration"]),
    (FEEDBACK_FILE, ["timestamp", "question", "subject_feedback", "topic_feedback"]),
):
    if os.path.exists(csv_path):
        continue
    with open(csv_path, "w", newline="") as new_file:
        csv.writer(new_file).writerow(header_row)
# ────────────────
# Inference functions
# ────────────────
def run_stage1(question, model_name):
    """Classify *question* against the coarse subject labels (stage 1).

    Args:
        question: Question text from the UI.
        model_name: Model id passed to ``get_pipeline``.

    Returns:
        A 3-tuple in the order of the stage-1 click outputs:
        ``{label: score}`` for the first three labels the classifier returns
        (scores rounded to 3 decimals), a ``gr.update`` that fills the subject
        radio with those labels and preselects the first, and the elapsed
        time as a display string. For a missing or whitespace-only question
        the tuple is ``({}, <cleared radio>, "")``.
    """
    if not question or not question.strip():
        # Reset the value as well as the choices so an earlier selection
        # cannot linger once the choices are emptied.
        return {}, gr.update(choices=[], value=None), ""

    start = time.time()
    clf = get_pipeline(model_name)
    out = clf(question, candidate_labels=coarse_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.time() - start, 3)

    subject_dict = {lbl: round(score, 3) for lbl, score in zip(labels, scores)}
    # Guard the preselection so an empty label list cannot raise IndexError.
    radio_update = gr.update(choices=labels, value=labels[0] if labels else None)
    return subject_dict, radio_update, f"⏱ {duration}s"
def run_stage2(question, model_name, subject):
    """Classify *question* against the fine topics of *subject* (stage 2).

    Every return path yields exactly two values, matching the two outputs
    wired to the stage-2 button (topics label, time textbox). Validation
    messages are returned in the second slot so they appear in the time box.

    Args:
        question: Question text from the UI.
        model_name: Model id passed to ``get_pipeline``.
        subject: Subject chosen in the stage-1 radio; may be ``None`` when
            nothing has been selected yet.

    Returns:
        ``(topic_dict, message)`` where ``topic_dict`` maps the first three
        returned topic labels to scores rounded to 3 decimals and ``message``
        is the elapsed-time string, or ``({}, <reason>)`` when validation
        fails.
    """
    # 1) Validate inputs
    if not question or not question.strip():
        return {}, "No question provided"
    if not subject:
        return {}, "No subject selected; run Stage 1 first"
    fine_labels = fine_map.get(subject, [])
    if not fine_labels:
        return {}, f"No topics found for '{subject}'"

    # 2) Inference (the pipeline is cached after its first use)
    start = time.time()
    clf = get_pipeline(model_name)
    out = clf(question, candidate_labels=fine_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.time() - start, 3)

    # 3) Logging — UTF-8 so non-ASCII questions and labels can always be written
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(labels),
            duration,
        ])

    # 4) Return topics + time
    topic_dict = {lbl: round(score, 3) for lbl, score in zip(labels, scores)}
    return topic_dict, f"⏱ {duration}s"
def submit_feedback(question, subject_fb, topic_fb):
    """Append one feedback row to ``FEEDBACK_FILE`` and return a status message.

    Args:
        question: Question text the feedback refers to; ``None`` is recorded
            as an empty string.
        subject_fb: Corrected subject entered by the user.
        topic_fb: Corrected topic(s) entered by the user.

    Returns:
        A confirmation string for the feedback status box.
    """
    # UTF-8 so free-text feedback in any script can always be written.
    with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            # Flatten line breaks so the question stays on one line in the CSV.
            (question or "").replace("\n", " "),
            subject_fb,
            topic_fb,
        ])
    return "✅ Feedback recorded!"
# ────────────────
# Build Gradio UI
# ────────────────
# NOTE(review): indentation was lost in this copy of the file. The nesting
# below (only the two inputs inside the Row, launch after the Blocks context)
# is the conventional reading — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")

    # Inputs shared by both stages and by the feedback form.
    with gr.Row():
        question_input = gr.Textbox(label="Enter your question", lines=3)
        model_input = gr.Dropdown(
            label="Choose model",
            choices=MODEL_CHOICES,
            value=MODEL_CHOICES[0],
        )

    # Stage 1 UI: coarse subjects plus a radio to pick one for stage 2.
    go_button = gr.Button("Run Stage 1")
    subject_out = gr.Label(label="Top-3 Subjects", num_top_classes=3)
    subj_radio = gr.Radio(label="Select Subject for Stage 2", choices=[])
    stage1_time = gr.Textbox(label="Stage 1 Time")
    go_button.click(
        run_stage1,
        [question_input, model_input],
        [subject_out, subj_radio, stage1_time],
    )

    # Stage 2 UI: fine topics for the selected subject.
    go2_button = gr.Button("Run Stage 2")
    topics_out = gr.Label(label="Top-3 Topics")
    stage2_time = gr.Textbox(label="Stage 2 Time")
    go2_button.click(
        run_stage2,
        [question_input, model_input, subj_radio],
        [topics_out, stage2_time],
    )

    # Feedback form: corrections are appended to the feedback CSV.
    gr.Markdown("---")
    gr.Markdown("### Feedback / Correction")
    subject_fb = gr.Textbox(label="Correct Subject")
    topic_fb = gr.Textbox(label="Correct Topic(s)")
    fb_button = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="")
    fb_button.click(
        submit_feedback,
        [question_input, subject_fb, topic_fb],
        [fb_status],
    )

demo.launch(share=True)