from typing import Optional

from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import os

app = FastAPI(
    title="Intent Classifier API",
    description="BERT-based intent classification system",
)

# Resolve the model and tokenizer directories relative to this file:
# go up two levels from this file, then one more to reach the project root.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
BASE_DIR = os.path.dirname(BASE_DIR)
MODEL_DIR = os.path.join(BASE_DIR, "intent_classifier_model")
TOKENIZER_DIR = os.path.join(BASE_DIR, "intent_classifier_tokenizer")

# Fail fast if the model or tokenizer directories are missing
if not os.path.isdir(MODEL_DIR):
    raise FileNotFoundError(f"Model directory not found: {MODEL_DIR}")
if not os.path.isdir(TOKENIZER_DIR):
    raise FileNotFoundError(f"Tokenizer directory not found: {TOKENIZER_DIR}")

# Load model and tokenizer from local directories only (never hit the Hub)
model = BertForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_DIR, local_files_only=True)
model.eval()  # inference mode: disables dropout

# Complete CLINC150 intent labels in exact order (151 total)
INTENT_LABELS = [
    'restaurant_reviews', 'nutrition_info', 'account_blocked', 'oil_change_how',
    'time', 'weather', 'redeem_rewards', 'interest_rate', 'gas_type',
    'accept_reservations', 'smart_home', 'user_name', 'report_lost_card',
    'repeat', 'whisper_mode', 'what_are_your_hobbies', 'order', 'jump_start',
    'schedule_meeting', 'meeting_schedule', 'freeze_account', 'what_song',
    'meaning_of_life', 'restaurant_reservation', 'traffic', 'make_call', 'text',
    'bill_balance', 'improve_credit_score', 'change_language', 'no',
    'measurement_conversion', 'timer', 'flip_coin', 'do_you_have_pets',
    'balance', 'tell_joke', 'last_maintenance', 'exchange_rate', 'uber',
    'car_rental', 'credit_limit', 'oos', 'shopping_list', 'expiration_date',
    'routing', 'meal_suggestion', 'tire_change', 'todo_list', 'card_declined',
    'rewards_balance', 'change_accent', 'vaccines', 'reminder_update',
    'food_last', 'change_ai_name', 'bill_due', 'who_do_you_work_for',
    'share_location', 'international_visa', 'calendar', 'translate', 'carry_on',
    'book_flight', 'insurance_change', 'todo_list_update', 'timezone',
    'cancel_reservation', 'transactions', 'credit_score', 'report_fraud',
    'spending_history', 'directions', 'spelling', 'insurance',
    'what_is_your_name', 'reminder', 'where_are_you_from', 'distance', 'payday',
    'flight_status', 'find_phone', 'greeting', 'alarm', 'order_status',
    'confirm_reservation', 'cook_time', 'damaged_card', 'reset_settings',
    'pin_change', 'replacement_card_duration', 'new_card', 'roll_dice',
    'income', 'taxes', 'date', 'who_made_you', 'pto_request', 'tire_pressure',
    'how_old_are_you', 'rollover_401k', 'pto_request_status', 'how_busy',
    'application_status', 'recipe', 'calendar_update', 'play_music', 'yes',
    'direct_deposit', 'credit_limit_change', 'gas', 'pay_bill',
    'ingredients_list', 'lost_luggage', 'goodbye', 'what_can_i_ask_you',
    'book_hotel', 'are_you_a_bot', 'next_song', 'change_speed', 'plug_type',
    'maybe', 'w2', 'oil_change_when', 'thank_you', 'shopping_list_update',
    'pto_balance', 'order_checks', 'travel_alert', 'fun_fact', 'sync_device',
    'schedule_maintenance', 'apr', 'transfer', 'ingredient_substitution',
    'calories', 'current_location', 'international_fees', 'calculator',
    'definition', 'next_holiday', 'update_playlist', 'mpg', 'min_payment',
    'change_user_name', 'restaurant_suggestion', 'travel_notification',
    'cancel', 'pto_used', 'travel_suggestion', 'change_volume',
]


def int2str(idx: int) -> str:
    """Map a class index to its intent label, or 'unknown' if out of range."""
    return INTENT_LABELS[idx] if 0 <= idx < len(INTENT_LABELS) else "unknown"


class Query(BaseModel):
    # Both fields are optional so clients can send either 'text' or 'message'
    text: Optional[str] = None
    message: Optional[str] = None


@app.post("/predict")
def predict_intent_compat(request: Query):
    """Compatibility endpoint that handles both 'text' and 'message' fields."""
    try:
        # Prefer 'message' but fall back to 'text' for compatibility
        text = request.message or request.text or ""
        if not text:
            return {"error": "No text or message provided"}

        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )
        with torch.no_grad():
            outputs = model(**inputs)
        prediction = outputs.logits.argmax(dim=-1).item()

        # Debug information
        print(f"Input: {text}")
        print(f"Raw prediction index: {prediction}")
        print(f"Total labels available: {len(INTENT_LABELS)}")

        intent = int2str(prediction)
        print(f"Mapped intent: {intent}")

        if intent == "oos":
            return {"intent": "out of scope (OOS)"}
        intent = intent.replace("_", " ").title()
        return {"intent": intent}
    except Exception as e:
        print(f"Error in prediction: {e}")
        return {"intent": "Error", "error": str(e)}


@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the main HTML interface"""
    html_content = """