# Source: Hugging Face upload by Sp2503 — commit 8630f65 ("Update main.py", verified)
import os
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
# --- Configuration ---
# Directory holding the fine-tuned BERT classifier (weights + tokenizer files),
# as saved by `save_pretrained`; loaded below via the transformers Auto* classes.
FINAL_MODEL_PATH = './final_bert_model_pdf'
# CSV with at least 'Intent' and 'Answer' columns, mapping each predicted
# intent label to its canned answer text.
SOLUTIONS_DATASET_PATH = 'qa_dataset_detailed_answers.csv'
# --- Load Models and Data ---
def load_resources():
    """Load the classifier, its tokenizer, and the intent->answer lookup.

    Returns:
        (model, tokenizer, solution_database) on success, where
        solution_database maps each 'Intent' label to its 'Answer' text;
        (None, None, None) if any resource fails to load (best-effort
        startup — the API still boots and reports the failure per request).
    """
    try:
        tok = AutoTokenizer.from_pretrained(FINAL_MODEL_PATH)
        classifier = AutoModelForSequenceClassification.from_pretrained(FINAL_MODEL_PATH)
        answers_df = pd.read_csv(SOLUTIONS_DATASET_PATH)
        intent_to_answer = answers_df.set_index('Intent')['Answer'].to_dict()
        print("✅ Resources loaded successfully!")
        return classifier, tok, intent_to_answer
    except Exception as e:
        # Broad catch is deliberate: any startup failure (missing files, bad
        # CSV schema, corrupt weights) degrades to the "not loaded" state.
        print(f"❌ Critical Error loading resources: {e}")
        return None, None, None


model, tokenizer, solution_database = load_resources()
# --- Initialize FastAPI ---
# Single application instance; the route decorators below attach to it.
app = FastAPI(title="Legal Aid API")
# --- API Data Models ---
class QueryRequest(BaseModel):
    """Request body for POST /get-solution."""
    question: str  # the user's free-text legal question
class SolutionResponse(BaseModel):
    """Response body for POST /get-solution."""
    predicted_intent: str  # label chosen by the classifier (or "Error")
    solution: str  # canned answer looked up for that label
# --- API Endpoints ---
@app.post("/get-solution", response_model=SolutionResponse)
def get_legal_solution(request: QueryRequest):
    """Classify a legal question and return the canned answer for its intent.

    Runs the fine-tuned BERT classifier on `request.question`, maps the
    argmax class id to its label via the model config, and looks the label
    up in the solution database.

    Returns:
        dict matching SolutionResponse; on unavailable resources or an
        unknown intent, a placeholder message is returned in `solution`
        (still HTTP 200, matching the original error-reporting contract).
    """
    # Fix: the original checked only `if not model:` — explicit `is None`
    # checks on all three globals cover every failure mode of load_resources().
    if model is None or tokenizer is None or solution_database is None:
        return {"predicted_intent": "Error", "solution": "Model not loaded."}
    # Tokenize to PyTorch tensors; truncation keeps long questions within
    # the model's maximum sequence length.
    inputs = tokenizer(request.question, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only — no gradient bookkeeping
        logits = model(**inputs).logits
    prediction_id = torch.argmax(logits, dim=-1).item()
    predicted_intent = model.config.id2label[prediction_id]
    # Unknown intents degrade to a fixed message rather than a KeyError.
    solution = solution_database.get(predicted_intent, "No solution found.")
    return {"predicted_intent": predicted_intent, "solution": solution}
@app.get("/")
def read_root():
    """Liveness probe: confirms the service is up and responding."""
    health_payload = {"status": "Legal Aid API is running."}
    return health_payload