# Hugging Face Space status when this file was captured: Sleeping
import os
from typing import Dict, List, Optional

import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pymongo import MongoClient
from pymongo.collection import Collection
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# --- Configuration ---
# Directory handed to from_pretrained() for both the tokenizer and the model.
LOCAL_MODEL_PATH = "./"
# MongoDB connection string read from the environment; None when unset.
MONGO_URI = os.environ.get("MONGO_URI")
# Database and collection used for the answer lookup.
DB_NAME = "legal_aid-chatbot"
COLLECTION_NAME = "categories"
# --- Global Resources ---
# These start empty and are assigned by startup_event(); the request handler
# checks model, tokenizer and collection for None before using them.
model: Optional[AutoModelForSequenceClassification] = None
tokenizer: Optional[AutoTokenizer] = None
# Maps the classifier's predicted class id to its intent label.
intent_map: Dict[int, str] = {}
# Assigned client[DB_NAME][COLLECTION_NAME], which is a pymongo Collection
# (not a MongoClient), so it is annotated as such.
collection: Optional[Collection] = None
def create_intent_map(csv_files: List[str]) -> Dict[int, str]:
    """Build the id-to-intent mapping from the training CSVs.

    The first file is read through its ``intent`` column and the second
    through its ``intent_type`` column. The distinct labels of both files are
    merged, sorted, and numbered from 0 in that sorted order.

    Args:
        csv_files: Paths of the training CSVs, in the order described above.
            Entries beyond the second are ignored; a shorter list simply
            contributes fewer labels.

    Returns:
        A dict from integer id to intent label, or an empty dict when one of
        the files does not exist.

    Raises:
        KeyError: If a file lacks its expected intent column.
    """
    # Column holding the label in each file, matched to csv_files by position.
    intent_columns = ("intent", "intent_type")
    all_intents = set()
    try:
        for csv_path, column in zip(csv_files, intent_columns):
            labels = pd.read_csv(csv_path)[column]
            # Rows with no label load as NaN (a float); keeping them would make
            # the sort below fail when comparing float against str.
            all_intents.update(labels.dropna().unique())
    except FileNotFoundError as e:
        print(f"β Critical Error: CSV for intent mapping not found: {e}")
        return {}
    return dict(enumerate(sorted(all_intents)))
# FastAPI application object for this service, with its title and version metadata.
app = FastAPI(
    title="Legal Aid Chatbot API",
    version="1.0.0",
)
def startup_event():
    """Load the intent map, the model/tokenizer and the MongoDB collection.

    Results are stored in the module globals ``intent_map``, ``tokenizer``,
    ``model`` and ``collection``. Every stage prints its outcome. A missing
    ``MONGO_URI`` or an empty intent map stops the function early, leaving the
    later globals untouched. A model-loading failure resets ``model`` to
    ``None`` but still goes on to the MongoDB stage. A MongoDB failure resets
    ``collection`` to ``None``.
    """
    # NOTE(review): no startup-hook registration (e.g. a decorator) is visible
    # for this function in this view -- confirm it is wired to the app startup.
    global model, tokenizer, intent_map, collection
    print("--- Loading resources on application startup ---")

    if not MONGO_URI:
        print("β Critical Error: MONGO_URI secret is not set in Hugging Face Space settings.")
        return

    intent_map = create_intent_map(['womens_legal_questions_20k.csv', 'legal_aid_chatbot_dataset_20k.csv'])
    if not intent_map:
        print("β Could not create intent map. API will not function correctly.")
        return

    try:
        tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
        model = AutoModelForSequenceClassification.from_pretrained(LOCAL_MODEL_PATH)
        print(f"β Model and tokenizer loaded successfully from '{LOCAL_MODEL_PATH}'.")
    except Exception as e:
        # Broad catch is deliberate: any loading problem must not abort startup.
        print(f"β Critical Error loading model from local directory: {e}")
        model = None

    client = None
    try:
        client = MongoClient(MONGO_URI)
        # server_info() forces a round trip, so an unreachable server fails here
        # instead of on the first request.
        client.server_info()
        collection = client[DB_NAME][COLLECTION_NAME]
        print("π Successfully connected to MongoDB.")
    except Exception as e:
        print(f"β Critical Error connecting to MongoDB: {e}")
        collection = None
        # Release the client created above; nothing else holds a reference to it.
        if client is not None:
            client.close()
class QueryRequest(BaseModel):
    """Request body carrying the user's question."""

    # Free-text question that is tokenized and classified.
    question: str
class SolutionResponse(BaseModel):
    """Response body pairing the predicted intent with its answer text."""

    # Intent label looked up for the predicted class id.
    predicted_intent: str
    # Answer text from the matching document, or a fallback message.
    solution: str
def get_legal_solution(request: QueryRequest):
    """Predict the intent of a question and return the stored answer for it.

    Args:
        request: Body whose ``question`` text is classified.

    Returns:
        A ``SolutionResponse`` with the predicted intent and either the
        ``answer`` field of the matching document or a fallback message.

    Raises:
        HTTPException: 503 when the model, tokenizer or collection is ``None``.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view -- confirm how it is registered with the app.
    required = (model, tokenizer, collection)
    if any(resource is None for resource in required):
        raise HTTPException(
            status_code=503,
            detail="Server resources are not ready. Check startup logs for errors.",
        )

    encoded = tokenizer(
        request.question,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(**encoded)
    prediction_id = output.logits.argmax(dim=1).item()

    # The label comes from the CSV-derived intent_map; ids that are not in the
    # map fall back to a placeholder label.
    predicted_intent = intent_map.get(prediction_id, "Unknown Intent")

    document = collection.find_one({"intent": predicted_intent})
    if not document or "answer" not in document:
        solution = "No specific solution was found for this topic."
    else:
        solution = document["answer"]

    return SolutionResponse(predicted_intent=predicted_intent, solution=solution)
def root():
    """Return a fixed status message saying the API is running."""
    status = {"message": "Legal Aid Chatbot API is active and running."}
    return status