| import os |
| from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification |
| import streamlit as st |
| import torch |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
|
|
| |
# Hugging Face credentials and model location. HF_TOKEN may be unset (None)
# for anonymous access to a public repo.
HF_TOKEN = os.getenv("HF_TOKEN")
model_path = "dejanseo/DEJAN-Taxonomy-Classifier"


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(path, token):
    """Load tokenizer and model once per server process.

    Streamlit re-executes this script on every user interaction; caching the
    loaded objects avoids re-downloading / re-initializing the model each run.
    """
    # `token=` replaces the deprecated `use_auth_token=` kwarg
    # (deprecated in transformers 4.x, removed in v5).
    tok = DebertaV2Tokenizer.from_pretrained(path, token=token)
    mdl = DebertaV2ForSequenceClassification.from_pretrained(path, token=token)
    return tok, mdl


tokenizer, model = _load_model_and_tokenizer(model_path, HF_TOKEN)
|
|
| |
# Google product-taxonomy IDs for the 21 top-level categories, listed in the
# order the model's output logits are arranged.
_TAXONOMY_IDS = (
    1, 8, 111, 141, 166, 222, 412, 436, 469, 536, 537,
    632, 772, 783, 888, 922, 988, 1239, 2092, 5181, 5605,
)

# taxonomy ID -> logit index
LABEL_MAPPING = {taxonomy_id: index for index, taxonomy_id in enumerate(_TAXONOMY_IDS)}
|
|
# Human-readable display name for each Google taxonomy ID the model predicts.
CATEGORY_NAMES = dict([
    (1, "Animals & Pet Supplies"),
    (8, "Arts & Entertainment"),
    (111, "Business & Industrial"),
    (141, "Cameras & Optics"),
    (166, "Apparel & Accessories"),
    (222, "Electronics"),
    (412, "Food, Beverages & Tobacco"),
    (436, "Furniture"),
    (469, "Health & Beauty"),
    (536, "Home & Garden"),
    (537, "Baby & Toddler"),
    (632, "Hardware"),
    (772, "Mature"),
    (783, "Media"),
    (888, "Vehicles & Parts"),
    (922, "Office Supplies"),
    (988, "Sporting Goods"),
    (1239, "Toys & Games"),
    (2092, "Software"),
    (5181, "Luggage & Bags"),
    (5605, "Religious & Ceremonial"),
])
|
|
|
|
| |
# logit index -> "[taxonomy_id] Category Name" display label (used for chart axes).
INDEX_TO_CATEGORY = {}
for _taxonomy_id, _index in LABEL_MAPPING.items():
    INDEX_TO_CATEGORY[_index] = f"[{_taxonomy_id}] {CATEGORY_NAMES[_taxonomy_id]}"
|
|
| |
# --- Page header and input widget -------------------------------------------
st.title("Google Taxonomy Classifier by DEJAN")

_intro_paragraphs = (
    "Enter text in the input box, and the model will classify it into one of the 21 top level categories. This demo showcases early model capability while the full 5000+ label model is undergoing extensive training.",
    "Works for product descriptions, search queries, articles, social media posts and broadly web text of any style. Suitable for classification pipelines of millions of queries.",
)
for _paragraph in _intro_paragraphs:
    st.write(_paragraph)

# Free-form text the user wants classified.
input_text = st.text_area("Enter text for classification:")
|
|
| |
def classify_text(text):
    """Return per-category softmax probabilities for *text*.

    Returns None when the input is blank or whitespace-only. Otherwise the
    result is a plain Python list of probabilities ordered by logit index
    (see LABEL_MAPPING), summing to ~1.0.
    """
    if not text.strip():
        return None

    # Encode to a single padded/truncated sequence (model limit: 512 tokens).
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only — gradients are unnecessary overhead here.
    with torch.no_grad():
        logits = model(**encoded).logits

    return F.softmax(logits, dim=-1).squeeze().tolist()
|
|
| |
# --- Main interaction: classify on button press ------------------------------
if st.button("Classify"):
    if input_text.strip():
        st.write("Processing...")

        probabilities = classify_text(input_text)
        if probabilities:
            # Pair each probability with its display label, then sort so the
            # most likely category comes first.
            mapped_probs = {INDEX_TO_CATEGORY[idx]: prob for idx, prob in enumerate(probabilities)}
            sorted_categories = sorted(mapped_probs.items(), key=lambda x: x[1], reverse=True)
            categories = [item[0] for item in sorted_categories]
            values = [item[1] for item in sorted_categories]

            # Horizontal bar chart of all category probabilities.
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.barh(categories, values)
            ax.set_xlabel("Probability")
            ax.set_ylabel("Category")
            ax.set_title("Classification Probabilities")
            ax.invert_yaxis()  # highest probability at the top
            ax.set_xlim(0, 1)
            st.pyplot(fig)
            # Streamlit reruns this whole script on every interaction; close
            # the figure so pyplot's global registry doesn't accumulate one
            # figure per classification (memory leak in long-lived sessions).
            plt.close(fig)

            st.divider()
            st.markdown("""
            Interested in using this in an automated pipeline for bulk link prediction? 
            Please [book an appointment](https://dejanmarketing.com/conference/) to discuss your needs.
            """)
        else:
            st.error("Could not classify the text. Please try again.")
    else:
        st.warning("Please enter some text for classification.")
|
|