Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from transformers import pipeline | |
| import uvicorn | |
| import streamlit as st | |
# Hugging Face Hub id of the fine-tuned token-classification checkpoint.
model_name = "DINGOLANI/distilbert-ner-v2"


# Streamlit re-executes this script on every widget interaction; caching the
# pipeline as a resource builds it once per process instead of on each rerun.
# show_spinner=False keeps the loader from emitting a UI element before
# main() has called st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_ner_pipeline(name):
    """Build the token-classification pipeline for checkpoint *name*.

    The same identifier is used for both the model weights and the tokenizer.
    """
    return pipeline("token-classification", model=name, tokenizer=name)


try:
    nlp_ner = _load_ner_pipeline(model_name)
except Exception as e:
    # Fail fast at import time with a clear message; keep the original
    # exception attached as the explicit cause.
    raise RuntimeError(f"Failed to load model: {e}") from e
# BIO tags in the order of the model's generic output ids, starting at
# LABEL_1. LABEL_0 is intentionally absent from the lookup tables, so tokens
# carrying it resolve to None (presumably the "O" / outside tag -- confirm
# against the training config).
_BIO_TAGS = (
    "B-BRAND",
    "I-BRAND",
    "B-CATEGORY",
    "I-CATEGORY",
    "B-GENDER",
    "B-PRICE",
    "I-PRICE",
)

# Generic pipeline label ("LABEL_<n>") -> BIO tag.
label_map = {
    f"LABEL_{position}": tag for position, tag in enumerate(_BIO_TAGS, start=1)
}

# BIO tag -> bare entity type (the part after the "B-" / "I-" prefix).
entity_filter = {tag: tag.partition("-")[2] for tag in _BIO_TAGS}
# FastAPI application object.
# NOTE(review): no route decorators are visible on `home` / `predict` in this
# file, so they do not appear to be registered on `app` -- confirm whether
# that is intended.
app = FastAPI()
def home():
    """Return the static status payload reporting that the API is up."""
    status = {"message": "NER API is running!"}
    return status
def _group_entities(tokens):
    """Collapse per-token predictions into ``{entity_type: [words, ...]}``.

    Each token's ``"entity"`` value is translated through ``label_map`` and
    ``entity_filter``; tokens that do not resolve to an entity type are
    skipped. A ``"##"`` continuation piece is appended to the previous word
    only when the immediately preceding token had the same entity type.
    """
    grouped = {}
    prev_entity = None  # entity type of the directly preceding token, if any
    for token in tokens:
        entity = entity_filter.get(label_map.get(token.get("entity")))
        if not entity:
            # An unlabelled token breaks adjacency, so a later "##" piece must
            # not be glued onto an earlier word of the same entity type.
            prev_entity = None
            continue

        word = token["word"]
        words = grouped.setdefault(entity, [])
        if word.startswith("##"):
            piece = word[2:]
            if prev_entity == entity and words:
                words[-1] += piece
            else:
                # Orphan continuation piece: keep it as its own entry.
                words.append(piece)
        else:
            words.append(word)
        prev_entity = entity
    return grouped


def predict(query: str):
    """Run the NER pipeline on *query* and group the recognised entities.

    Args:
        query: Free-text search query.

    Returns:
        A dict with three keys: ``"query"`` (the input, unchanged),
        ``"raw_output"`` (the pipeline's per-token predictions with
        ``score`` converted to a built-in ``float``) and
        ``"structured_output"`` (entity type -> list of extracted words).

    Raises:
        HTTPException: status 500 if inference or post-processing fails.
    """
    # Nothing to tag in a blank query; skip the model call entirely.
    if not query or not query.strip():
        return {"query": query, "raw_output": [], "structured_output": {}}

    try:
        result = nlp_ner(query)
        # Convert scores to plain floats so the payload is JSON-serialisable.
        for token in result:
            token["score"] = float(token["score"])
        print("RAW MODEL OUTPUT:", result)

        return {
            "query": query,
            "raw_output": result,
            "structured_output": _group_entities(result),
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing request: {e}"
        ) from e
# Streamlit frontend
def main():
    """Render the Streamlit page: query input, Analyze button, and results.

    When the button is pressed, the query is passed to ``predict``. The
    grouped entities are shown in the left column and the raw per-token
    model output in the right column. A failure reported by ``predict`` is
    shown as an error message instead of an unhandled traceback.
    """
    # NOTE(review): the leading "π" in the titles below looks like a
    # mis-decoded emoji -- confirm the intended glyph before changing it.
    st.set_page_config(page_title="Luxury Fashion NER", layout="wide")
    st.title("π Luxury Fashion Entity Extractor")
    st.write("Enter a text query and extract structured entities like **Brand, Category, Gender, and Price.**")
    query = st.text_input("Enter Query:", "Gucci handbags for women under $5000")

    if not st.button("Analyze"):
        return

    try:
        response = predict(query)
    except HTTPException as exc:
        # predict() wraps every failure in an HTTPException; show its message.
        st.error(exc.detail)
        return

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("π Structured Output")
        entities = response["structured_output"]
        if not entities:
            # Avoid an unexplained blank column when nothing was recognised.
            st.info("No entities found.")
        for key, value in entities.items():
            st.write(f"**{key}:** {', '.join(value)}")
    with col2:
        st.subheader("π Raw Model Output")
        st.json(response["raw_output"])


if __name__ == "__main__":
    main()