File size: 1,903 Bytes
4bea9cd
a2464d3
ccd41a5
 
8630f65
 
 
a2464d3
8630f65
 
4bea9cd
 
8630f65
ebb8768
a2464d3
4bea9cd
 
8630f65
 
 
 
a2464d3
8630f65
 
a2464d3
8630f65
cb5b1d2
8630f65
 
17205ab
8630f65
cb5b1d2
 
17205ab
cb5b1d2
 
 
a2464d3
8630f65
cb5b1d2
 
8630f65
 
 
 
 
 
 
 
 
 
 
4bea9cd
cb5b1d2
8630f65
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

import os
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd

# --- Configuration ---
# Directory containing the fine-tuned BERT weights + tokenizer files
# (loaded via AutoTokenizer / AutoModelForSequenceClassification below).
FINAL_MODEL_PATH = './final_bert_model_pdf' 
# CSV expected to have 'Intent' and 'Answer' columns (see load_resources).
SOLUTIONS_DATASET_PATH = 'qa_dataset_detailed_answers.csv'

# --- Load Models and Data ---
def load_resources():
    """Load the fine-tuned classifier, its tokenizer, and the intent->answer table.

    Returns:
        (model, tokenizer, solution_database) on success, or
        (None, None, None) if any resource fails to load.
    """
    try:
        bert_tokenizer = AutoTokenizer.from_pretrained(FINAL_MODEL_PATH)
        bert_model = AutoModelForSequenceClassification.from_pretrained(FINAL_MODEL_PATH)
        # Build a plain dict mapping each intent label to its canned answer.
        answers_df = pd.read_csv(SOLUTIONS_DATASET_PATH)
        answer_lookup = dict(zip(answers_df['Intent'], answers_df['Answer']))
        print("✅ Resources loaded successfully!")
        return bert_model, bert_tokenizer, answer_lookup
    except Exception as e:
        # Any failure (missing files, bad CSV, bad checkpoint) degrades to
        # three None sentinels; the endpoint guards against them at runtime.
        print(f"❌ Critical Error loading resources: {e}")
        return None, None, None

# Load everything once at import time so all requests share the same
# model/tokenizer/answer table. All three are None if loading failed.
model, tokenizer, solution_database = load_resources()

# --- Initialize FastAPI ---
app = FastAPI(title="Legal Aid API")

# --- API Data Models ---
class QueryRequest(BaseModel):
    """Request body for /get-solution: the user's free-text legal question."""
    question: str

class SolutionResponse(BaseModel):
    """Response body for /get-solution."""
    # Label predicted by the classifier ("Error" when resources are missing).
    predicted_intent: str
    # Canned answer looked up for that intent ("No solution found." fallback).
    solution: str

# --- API Endpoints ---
@app.post("/get-solution", response_model=SolutionResponse)
def get_legal_solution(request: QueryRequest):
    """Classify the user's question into an intent and return its canned solution.

    Args:
        request: Body containing the free-text ``question``.

    Returns:
        dict with ``predicted_intent`` and ``solution`` keys (validated
        against SolutionResponse by FastAPI).
    """
    # Guard ALL resources, not just the model: if load_resources() failed,
    # solution_database is also None and .get() below would raise
    # AttributeError. Use explicit `is None` rather than object truthiness.
    if model is None or tokenizer is None or solution_database is None:
        return {"predicted_intent": "Error", "solution": "Model not loaded."}

    # Tokenize and run a single forward pass with gradients disabled.
    inputs = tokenizer(request.question, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    prediction_id = torch.argmax(logits, dim=1).item()
    predicted_intent = model.config.id2label[prediction_id]
    # Intents missing from the CSV fall back to a default message.
    solution = solution_database.get(predicted_intent, "No solution found.")

    return {"predicted_intent": predicted_intent, "solution": solution}

@app.get("/")
def read_root():
    """Health check: confirm the API process is alive and serving."""
    return dict(status="Legal Aid API is running.")