Spaces:
Sleeping
Sleeping
File size: 1,903 Bytes
4bea9cd a2464d3 ccd41a5 8630f65 a2464d3 8630f65 4bea9cd 8630f65 ebb8768 a2464d3 4bea9cd 8630f65 a2464d3 8630f65 a2464d3 8630f65 cb5b1d2 8630f65 17205ab 8630f65 cb5b1d2 17205ab cb5b1d2 a2464d3 8630f65 cb5b1d2 8630f65 4bea9cd cb5b1d2 8630f65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import os
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
# --- Configuration ---
# Each path can be overridden through an environment variable of the same name
# (e.g. when the model or dataset lives outside the working directory); when
# the variable is unset the original relative path is used.
FINAL_MODEL_PATH = os.environ.get("FINAL_MODEL_PATH", "./final_bert_model_pdf")
SOLUTIONS_DATASET_PATH = os.environ.get(
    "SOLUTIONS_DATASET_PATH", "qa_dataset_detailed_answers.csv"
)
# --- Load Models and Data ---
def load_resources(model_path=None, dataset_path=None):
    """Load the intent classifier, its tokenizer and the answer lookup table.

    Args:
        model_path: Location passed to ``from_pretrained`` for both the
            tokenizer and the model. ``None`` (the default) means
            ``FINAL_MODEL_PATH``.
        dataset_path: CSV file read with ``pd.read_csv``; it must contain
            ``Intent`` and ``Answer`` columns. ``None`` (the default) means
            ``SOLUTIONS_DATASET_PATH``.

    Returns:
        ``(model, tokenizer, solution_database)`` where ``solution_database``
        maps intent (str) -> answer (str). If anything fails, the error is
        printed and ``(None, None, None)`` is returned instead of raising, so
        the module can still be imported and the API can report the failure.
    """
    # Resolve defaults at call time (not definition time) so the module-level
    # constants are read when loading actually happens.
    if model_path is None:
        model_path = FINAL_MODEL_PATH
    if dataset_path is None:
        dataset_path = SOLUTIONS_DATASET_PATH
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(model_path)
        model.eval()  # inference only: make sure dropout-style layers are off
        solutions_df = pd.read_csv(dataset_path)
        # Empty Intent/Answer cells are read as NaN (a float); such an entry
        # would later fail the `solution: str` response validation, so skip it.
        solutions_df = solutions_df.dropna(subset=["Intent", "Answer"])
        # The endpoint looks answers up by the label name from
        # model.config.id2label, so force string keys. With duplicate intents
        # the last row wins, exactly as with the previous to_dict() approach.
        solution_database = dict(
            zip(
                solutions_df["Intent"].astype(str),
                solutions_df["Answer"].astype(str),
            )
        )
        print("✅ Resources loaded successfully!")
        return model, tokenizer, solution_database
    except Exception as e:
        # Deliberate catch-all at the startup boundary: a missing model or CSV
        # must not crash the import; callers detect failure via the Nones.
        print(f"❌ Critical Error loading resources: {e}")
        return None, None, None
# Load everything once, at import time, so every request reuses the same
# objects. On failure load_resources() yields (None, None, None) and the
# service still starts; the /get-solution endpoint checks `model` first.
(
    model,
    tokenizer,
    solution_database,
) = load_resources()

# --- Initialize FastAPI ---
app = FastAPI(
    title="Legal Aid API",
)
# --- API Data Models ---
# Request body accepted by POST /get-solution.
class QueryRequest(BaseModel):
    # Free-text legal question that the classifier maps to an intent.
    question: str
# Response body returned by POST /get-solution (used as its response_model).
class SolutionResponse(BaseModel):
    # Label predicted by the model, or "Error" when the request could not be served.
    predicted_intent: str
    # Answer text stored for the predicted intent, or an error/fallback message.
    solution: str
# --- API Endpoints ---
@app.post("/get-solution", response_model=SolutionResponse)
def get_legal_solution(request: QueryRequest):
    """Predict the intent of a legal question and return its stored answer.

    Problems are reported in-band (a normal 200 response) with
    ``predicted_intent`` set to ``"Error"``, following the existing
    "Model not loaded." convention.

    Args:
        request: Body containing the user's ``question``.

    Returns:
        Dict with ``predicted_intent`` (the model's label for the question)
        and ``solution`` (the matching answer, or "No solution found.").
    """
    # load_resources() returns (None, None, None) on failure, so test for
    # None explicitly instead of relying on the model object's truthiness.
    if model is None or tokenizer is None or solution_database is None:
        return {"predicted_intent": "Error", "solution": "Model not loaded."}
    # A blank question would still be forced into some intent; refuse it
    # rather than return a meaningless answer.
    if not request.question.strip():
        return {"predicted_intent": "Error", "solution": "Question must not be empty."}

    inputs = tokenizer(request.question, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only: no autograd bookkeeping
        logits = model(**inputs).logits
    # One question in -> one row of logits; pick the highest-scoring class.
    prediction_id = int(torch.argmax(logits, dim=-1).item())
    predicted_intent = model.config.id2label[prediction_id]

    solution = solution_database.get(predicted_intent)
    # An unknown intent, or a NaN answer coming from an empty CSV cell, falls
    # back to the default message instead of failing `solution: str`
    # validation with an HTTP 500.
    if solution is None or pd.isna(solution):
        solution = "No solution found."
    return {"predicted_intent": predicted_intent, "solution": str(solution)}
@app.get("/")
def read_root():
    """Liveness check: report that the API process is up.

    This does not verify that the model loaded; /get-solution reports
    "Model not loaded." in that case.
    """
    return {"status": "Legal Aid API is running."}