"""HTTP API exposing named-entity recognition and emotion classification.

Two Hugging Face pipelines are loaded once at import time (NER first, then
the emotion classifier) and served through FastAPI endpoints:

* ``GET  /``                  -- health check
* ``POST /analyze/ner``       -- persons / locations / organizations in a text
* ``POST /analyze/sentiment`` -- top emotion and signed polarity per sentence
"""

import logging
from typing import Any, Dict, Iterable, List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

logger = logging.getLogger(__name__)

app = FastAPI(title="NER + Emotion API")

# NER labels reported by this API, mapped to their response keys. Entities
# carrying any other label are ignored. Insertion order fixes the key order
# of the NER response.
_ENTITY_BUCKETS = {
    "PER": "persons",
    "LOC": "locations",
    "ORG": "organizations",
}

# Emotion labels that yield a positive / negative polarity. Every other label
# is reported with a polarity of 0.0.
_POSITIVE_EMOTIONS = frozenset({"joy"})
_NEGATIVE_EMOTIONS = frozenset({"anger", "disgust", "fear", "sadness"})

# ---------------------------------------------------------
# LOAD NER FIRST (PRIORITY LOAD)
# ---------------------------------------------------------
print("Loading NER model...")
ner_pipeline = pipeline(
    "ner",
    model="dslim/bert-base-NER",
    aggregation_strategy="simple",
)
print("NER model loaded.")

# ---------------------------------------------------------
# LOAD SENTIMENT SECOND
# ---------------------------------------------------------
print("Loading Sentiment model...")
sentiment_pipeline = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    top_k=1,
)
print("Sentiment model loaded.")


# ---------------------------------------------------------
# REQUEST MODELS
# ---------------------------------------------------------
class TextInput(BaseModel):
    """Request body for ``/analyze/ner``: a single text to tag."""

    text: str


class SentimentInput(BaseModel):
    """Request body for ``/analyze/sentiment``: sentences to classify."""

    sentences: List[str]


# ---------------------------------------------------------
# HELPERS
# ---------------------------------------------------------
def _group_entities(entities: Iterable[Dict[str, Any]]) -> Dict[str, List[str]]:
    """Bucket aggregated NER entities into persons/locations/organizations.

    Each entity is read through its ``"entity_group"`` and ``"word"`` keys.
    Words are stripped of surrounding whitespace and de-duplicated while
    keeping first-seen order, so identical input yields identical output.
    """
    # A dict is used as an ordered set: keys are unique and keep insertion order.
    grouped: Dict[str, Dict[str, None]] = {
        bucket: {} for bucket in _ENTITY_BUCKETS.values()
    }
    for entity in entities:
        bucket = _ENTITY_BUCKETS.get(entity["entity_group"])
        if bucket is None:
            continue
        grouped[bucket][entity["word"].strip()] = None
    return {bucket: list(words) for bucket, words in grouped.items()}


def _polarity(label: str, score: float) -> float:
    """Return *score* signed by the emotion *label* (0.0 when neither side)."""
    if label in _POSITIVE_EMOTIONS:
        return score
    if label in _NEGATIVE_EMOTIONS:
        return -score
    return 0.0


# ---------------------------------------------------------
# HEALTH CHECK
# ---------------------------------------------------------
@app.get("/")
def home():
    """Report that the service is up."""
    return {"message": "NER + Emotion API is running"}


# ---------------------------------------------------------
# NER ENDPOINT
# ---------------------------------------------------------
@app.post("/analyze/ner")
def analyze_ner(data: TextInput):
    """Extract unique person, location and organization names from a text.

    Returns a dict with the keys ``"persons"``, ``"locations"`` and
    ``"organizations"``, each a list of distinct entity strings in order of
    first appearance.

    Raises:
        HTTPException: status 500 carrying the error message when the
            pipeline call or result processing fails.
    """
    # Nothing to tag in blank input; skip the model call entirely.
    if not data.text.strip():
        return _group_entities([])

    try:
        # The aggregation strategy is configured once at pipeline
        # construction. No ``truncation`` argument is passed here: an earlier
        # revision that passed it produced 500 errors.
        # NOTE(review): over-long inputs are therefore not shortened by this
        # code; the pipeline's ``stride`` option may help -- confirm against
        # the installed transformers version before enabling it.
        entities = ner_pipeline(data.text)
        return _group_entities(entities)
    except Exception as exc:
        # Record the full traceback so the failure is visible in the logs.
        logger.exception("NER analysis failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


# ---------------------------------------------------------
# SENTIMENT ENDPOINT
# ---------------------------------------------------------
@app.post("/analyze/sentiment")
def analyze_sentiment(data: SentimentInput):
    """Classify the emotion of each sentence and derive a signed polarity.

    Returns ``{"results": [...]}`` holding one entry per input sentence with
    the keys ``"label"``, ``"confidence"`` and ``"polarity"``. Polarity is
    ``+confidence`` for joy, ``-confidence`` for anger, disgust, fear and
    sadness, and ``0.0`` for any other label.

    Raises:
        HTTPException: status 500 carrying the error message when the
            pipeline call or result processing fails.
    """
    # No sentences means no work; avoid calling the model with an empty batch.
    if not data.sentences:
        return {"results": []}

    try:
        predictions = sentiment_pipeline(
            data.sentences,
            truncation=True,
            max_length=512,
        )

        processed_results = []
        for ranked in predictions:
            # The pipeline is built with top_k=1; the first entry of each
            # per-sentence result is taken as the top prediction.
            top_result = ranked[0]
            label = top_result["label"]
            score = top_result["score"]
            processed_results.append(
                {
                    "label": label,
                    "confidence": score,
                    "polarity": _polarity(label, score),
                }
            )

        return {"results": processed_results}
    except Exception as exc:
        # Log here as well so both endpoints report failures the same way.
        logger.exception("Sentiment analysis failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)