| from fastapi import FastAPI, Request, Form |
| from fastapi.responses import HTMLResponse |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.templating import Jinja2Templates |
| from pydantic import BaseModel |
| from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification |
| import numpy as np |
|
|
# ASGI application object; the route handlers below register against it.
app = FastAPI()

# Mount the local "static" directory under the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# Template loader rooted at the local "templates" directory.
templates = Jinja2Templates(
    directory="templates",
)
|
|
| |
# Checkpoint identifier used to load both the tokenizer and the model.
model_name = "IsmatS/xlm-roberta-az-ner"

# Both are loaded once, at import time, and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

# NER pipeline built with aggregation_strategy="simple"; its results are read
# through the "entity_group" and "word" keys in predict_ner.
nlp_ner = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)
|
|
# Maps each "LABEL_<i>" tag to a human-readable entity type name.
# The tuple is ordered so that the name at index i belongs to "LABEL_<i>".
label_mapping = {
    f"LABEL_{position}": title
    for position, title in enumerate(
        (
            "Other",
            "Person",
            "Location",
            "Organization",
            "Date",
            "Time",
            "Money",
            "Percentage",
            "Facility",
            "Product",
            "Event",
            "Art",
            "Law",
            "Language",
            "Government",
            "Nationality or Religion",
            "Ordinal",
            "Cardinal",
            "Disease",
            "Contact",
            "Proverb or Saying",
            "Quantity",
            "Miscellaneous",
            "Position",
            "Project",
        )
    )
}
|
|
def convert_numpy_types(obj):
    """Recursively replace NumPy values with plain Python equivalents.

    Any NumPy boolean, integer or floating scalar (not just ``float32`` and
    ``int32``) is converted to ``bool``, ``int`` or ``float``. NumPy arrays
    become (nested) lists. Lists and dict values are walked recursively;
    dict keys are left untouched.

    Args:
        obj: A NumPy scalar, NumPy array, list, dict, or any other object.

    Returns:
        The converted structure. Objects of any other type, including
        tuples, are returned unchanged.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    # The abstract bases cover every width (int8..int64, float16..float64).
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        # tolist() yields Python scalars for numeric dtypes; recurse so that
        # object arrays holding NumPy scalars are cleaned up as well.
        return convert_numpy_types(obj.tolist())
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    return obj
|
|
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the ``index.html`` template for the site root.

    Args:
        request: The incoming request, passed to the template context.

    Returns:
        The template response produced from ``index.html``.
    """
    # The template context carries the request under the "request" key.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
@app.post("/predict/")
async def predict_ner(text: str = Form(...)):
    """Run the NER pipeline on the submitted text and group hits by type.

    Args:
        text: Required form field holding the text to analyse.

    Returns:
        A dict of the form ``{"entities": {type_name: [word, ...]}}``.
        Words are listed in the order the pipeline reported them, and
        entities whose type resolves to "Other" are omitted.
    """
    grouped = {}
    for hit in nlp_ner(text):
        raw_group = hit["entity_group"]
        # Fall back to the raw group name when it has no friendly mapping.
        kind = label_mapping.get(raw_group, raw_group)
        if kind != "Other":
            grouped.setdefault(kind, []).append(hit["word"])
    return {"entities": grouped}
|
|
|
|
| |
| |
| |
| |
|
|
| |