# BERT_MODEL / main.py
# Sp2503's picture
# Update main.py
# e21572d verified
import os
import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pymongo import MongoClient
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Dict, List, Optional
# --- Configuration ---
# Directory handed to both `from_pretrained` calls in the startup hook.
LOCAL_MODEL_PATH = "./"
# MongoDB connection string, read from the environment; the startup hook
# aborts resource loading when it is unset.
MONGO_URI = os.getenv("MONGO_URI")
# Database and collection that are searched for the answer document.
DB_NAME = "legal_aid-chatbot"
COLLECTION_NAME = "categories"
# --- Global Resources ---
# All four slots are filled in by the startup hook. The request handler
# treats a `None` model, tokenizer or collection as "service not ready".
model: Optional[AutoModelForSequenceClassification] = None
tokenizer: Optional[AutoTokenizer] = None
# Class index predicted by the model -> human-readable intent label.
intent_map: Dict[int, str] = {}
# NOTE(review): annotated as MongoClient, but the startup hook assigns
# `client[DB_NAME][COLLECTION_NAME]`, i.e. a collection handle -- confirm
# and tighten the annotation.
collection: Optional[MongoClient] = None
def create_intent_map(csv_files: List[str]) -> Dict[int, str]:
    """Build the class-index -> intent-label mapping from the training CSVs.

    The labels of both files are merged, de-duplicated and sorted, and each
    label is numbered by its position in that sorted order.

    Args:
        csv_files: Paths of the training CSVs. The first file must provide an
            ``intent`` column and the second an ``intent_type`` column; any
            further entries are ignored.

    Returns:
        A dict mapping each integer index to its label, or an empty dict when
        the inputs cannot be used (too few paths, a missing or empty file, or
        a missing label column). The reason is printed in that case.
    """
    # Label column of each CSV, in the order the files are expected.
    label_columns = ("intent", "intent_type")
    if len(csv_files) < len(label_columns):
        print(
            f"Critical Error: expected {len(label_columns)} CSV files for "
            f"intent mapping, got {len(csv_files)}"
        )
        return {}

    all_intents = set()
    try:
        for csv_path, column in zip(csv_files, label_columns):
            frame = pd.read_csv(csv_path)
            # Blank label cells are read as NaN (a float); mixing them with
            # string labels would make sorted() below raise TypeError.
            all_intents.update(frame[column].dropna().unique())
    except FileNotFoundError as e:
        print(f"❌ Critical Error: CSV for intent mapping not found: {e}")
        return {}
    except (KeyError, pd.errors.EmptyDataError) as e:
        print(f"Critical Error: CSV for intent mapping is missing its label column: {e}")
        return {}

    # Sorting makes the numbering independent of row order in the CSVs.
    return dict(enumerate(sorted(all_intents)))
# Application object; the startup hook and the route handlers below are
# registered on it through its decorators.
app = FastAPI(title="Legal Aid Chatbot API", version="1.0.0")
@app.on_event("startup")
def startup_event():
    """Load the intent map, the classifier and the MongoDB collection.

    Fills the module-level ``model``, ``tokenizer``, ``intent_map`` and
    ``collection``. Failures are reported with ``print`` rather than raised,
    so the application still starts; a resource that could not be loaded
    stays ``None`` and the ``/get-solution`` handler answers 503.
    """
    global model, tokenizer, intent_map, collection
    print("--- Loading resources on application startup ---")

    if not MONGO_URI:
        print("❌ Critical Error: MONGO_URI secret is not set in Hugging Face Space settings.")
        return

    intent_map = create_intent_map(['womens_legal_questions_20k.csv', 'legal_aid_chatbot_dataset_20k.csv'])
    if not intent_map:
        print("❌ Could not create intent map. API will not function correctly.")
        return

    try:
        tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
        model = AutoModelForSequenceClassification.from_pretrained(LOCAL_MODEL_PATH)
        print(f"βœ… Model and tokenizer loaded successfully from '{LOCAL_MODEL_PATH}'.")
    except Exception as e:
        print(f"❌ Critical Error loading model from local directory: {e}")
        # Clear both: the tokenizer may have loaded before the model failed,
        # and a half-initialised pair must not look usable.
        model = None
        tokenizer = None

    # A model failure does not skip the database step, so the log shows the
    # state of every resource in one start-up pass.
    client = None
    try:
        client = MongoClient(MONGO_URI)
        # server_info() sends a command to the server, so an unreachable or
        # misconfigured deployment is reported here instead of on a request.
        client.server_info()
        collection = client[DB_NAME][COLLECTION_NAME]
        print("πŸš€ Successfully connected to MongoDB.")
    except Exception as e:
        print(f"❌ Critical Error connecting to MongoDB: {e}")
        collection = None
        # Release the client created above; it is of no use once the
        # connectivity check has failed.
        if client is not None:
            client.close()
# Request body accepted by POST /get-solution.
class QueryRequest(BaseModel):
    # The user's question; it is passed to the tokenizer unchanged.
    question: str
# Response body returned by POST /get-solution.
class SolutionResponse(BaseModel):
    # Label looked up in intent_map, or "Unknown Intent" when the predicted
    # index has no entry.
    predicted_intent: str
    # The "answer" field of the matching MongoDB document, or a fixed
    # fallback sentence when no such document/field exists.
    solution: str
@app.post("/get-solution", response_model=SolutionResponse)
def get_legal_solution(request: QueryRequest):
    """Classify the question and return the stored answer for its intent.

    Raises:
        HTTPException: 503 when the model, tokenizer or collection was not
            loaded during startup.
    """
    # Identity checks only: the resources are compared against None, never
    # evaluated for truthiness.
    resources_ready = all(
        resource is not None for resource in (model, tokenizer, collection)
    )
    if not resources_ready:
        raise HTTPException(status_code=503, detail="Server resources are not ready. Check startup logs for errors.")

    encoded = tokenizer(request.question, return_tensors="pt", truncation=True, padding=True)
    # Inference only, so gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        scores = model(**encoded).logits
    class_index = scores.argmax(dim=1).item()

    # The label comes from the CSV-derived intent_map, keyed by class index.
    predicted_intent = intent_map.get(class_index, "Unknown Intent")

    document = collection.find_one({"intent": predicted_intent})
    if document and "answer" in document:
        solution = document["answer"]
    else:
        solution = "No specific solution was found for this topic."

    return SolutionResponse(predicted_intent=predicted_intent, solution=solution)
@app.get("/")
def root():
    """Return a fixed status message showing that the API is reachable."""
    status_payload = {"message": "Legal Aid Chatbot API is active and running."}
    return status_payload