Sp2503 committed on
Commit
8630f65
·
verified ·
1 Parent(s): 45bb406

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +30 -90
main.py CHANGED
@@ -1,116 +1,56 @@
1
- # main.py
2
 
3
  import os
4
- import torch
5
- import pandas as pd
6
  from fastapi import FastAPI
7
  from pydantic import BaseModel
8
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
9
- from langdetect import detect
 
10
 
11
# ========== CONFIG ==========
# Model locations: a locally fine-tuned English classifier plus a hosted
# multilingual MuRIL model, and the CSV mapping intents to answers.
FINAL_MODEL_PATH = './final_bert_model_pdf'  # Local fine-tuned English model
MURIL_MODEL_ID = 'Sp2503/Muril-Model'  # Hugging Face multilingual model
SOLUTIONS_DATASET_PATH = 'qa_dataset_detailed_answers.csv'

# Redirect every Hugging Face cache location to a writable directory;
# the original note says this fixes cache permissions on HF Spaces.
HF_CACHE_DIR = "/tmp/hf_cache"
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
os.makedirs(HF_CACHE_DIR, exist_ok=True)
20
 
21
-
22
-
23
- # ========== LOAD MODELS ==========
24
def load_resources():
    """Load every model and dataset the service needs, once, at startup.

    Returns:
        (model, tokenizer, muril_pipeline, solution_db) on success, or
        (None, None, None, None) when anything fails to load.
    """
    try:
        # English fine-tuned BERT classifier and its tokenizer.
        eng_tokenizer = AutoTokenizer.from_pretrained(FINAL_MODEL_PATH)
        eng_model = AutoModelForSequenceClassification.from_pretrained(FINAL_MODEL_PATH)

        # Multilingual MuRIL pipeline used for non-English questions.
        multilingual = pipeline("text-classification", model=MURIL_MODEL_ID)

        # Intent -> detailed answer lookup table from the CSV.
        answers = pd.read_csv(SOLUTIONS_DATASET_PATH)
        intent_to_answer = answers.set_index('Intent')['Answer'].to_dict()

        print("✅ All models & data loaded successfully!")
        return eng_model, eng_tokenizer, multilingual, intent_to_answer
    except Exception as e:
        print(f"❌ Error loading models or data: {e}")
        return None, None, None, None
42
 
43
# Load everything once at import time; request handlers reuse these globals.
model, tokenizer, muril_pipeline, solution_db = load_resources()

# ========== FASTAPI APP ==========
app = FastAPI(title="AI LegalAid Chatbot Server")
47
 
48
- # Request / Response Schemas
49
class QueryRequest(BaseModel):
    # Incoming payload for POST /get-solution.
    question: str  # raw user question text; language is detected downstream
51
 
52
class SolutionResponse(BaseModel):
    # Response schema for POST /get-solution.
    predicted_intent: str  # label from the classifier, or "Error" on failure
    solution: str  # canned answer looked up from the CSV by intent
    model_used: str  # "MuRIL", "English BERT", or "None" when resources missing
56
 
57
- # ========== LOGIC ==========
58
@app.post("/get-solution", response_model=SolutionResponse)
def get_legal_solution(request: QueryRequest):
    """Classify a legal question and return the matching canned answer.

    Non-English questions (per langdetect) are routed to the MuRIL
    pipeline; English questions use the locally fine-tuned BERT model.

    Args:
        request: QueryRequest carrying the raw user question.

    Returns:
        A dict matching SolutionResponse: the predicted intent, the
        answer text from the CSV lookup, and which model was used.
    """
    # Guard every resource that is dereferenced below — load_resources()
    # returns all-None on failure, and the original check missed solution_db,
    # which would crash the lookup with an AttributeError.
    if not model or not tokenizer or solution_db is None:
        return {
            "predicted_intent": "Error",
            "solution": "Model not loaded properly.",
            "model_used": "None"
        }

    question = request.question.strip()

    # Detect language; langdetect can raise (e.g. on empty/undetectable
    # text), so default to English. Narrowed from a bare `except:` so
    # SystemExit/KeyboardInterrupt are no longer swallowed.
    try:
        lang = detect(question)
    except Exception:
        lang = "en"

    # Non-English questions go to the multilingual MuRIL model.
    if lang != "en":
        try:
            muril_result = muril_pipeline(question)
            predicted_intent = muril_result[0]['label']
            solution = solution_db.get(predicted_intent, "No solution found for this intent.")
            return {
                "predicted_intent": predicted_intent,
                "solution": solution,
                "model_used": "MuRIL"
            }
        except Exception as e:
            return {
                "predicted_intent": "Error",
                "solution": f"MuRIL model failed: {e}",
                "model_used": "MuRIL"
            }

    # English questions use the fine-tuned BERT model.
    try:
        inputs = tokenizer(question, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        prediction_id = torch.argmax(logits, dim=1).item()
        predicted_intent = model.config.id2label[prediction_id]
        solution = solution_db.get(predicted_intent, "No solution found for this intent.")
        return {
            "predicted_intent": predicted_intent,
            "solution": solution,
            "model_used": "English BERT"
        }
    except Exception as e:
        return {
            "predicted_intent": "Error",
            "solution": f"English model failed: {e}",
            "model_used": "English BERT"
        }
112
 
113
@app.get("/")
def root():
    """Liveness probe that also reports whether all models loaded."""
    # A failed load leaves these as None (falsy); loaded objects are truthy.
    ready = bool(model) and bool(tokenizer) and bool(muril_pipeline)
    return {"status": "✅ AI LegalAid Chatbot Running", "models_ready": ready}
 
 
1
 
2
  import os
 
 
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ import pandas as pd
8
 
9
# --- Configuration ---
# Path of the locally fine-tuned BERT checkpoint and of the CSV that maps
# each predicted intent to its detailed answer text.
FINAL_MODEL_PATH = './final_bert_model_pdf'
SOLUTIONS_DATASET_PATH = 'qa_dataset_detailed_answers.csv'
 
 
 
 
 
12
 
13
+ # --- Load Models and Data ---
 
 
14
def load_resources():
    """Load the fine-tuned classifier, its tokenizer, and the answer table.

    Returns:
        (model, tokenizer, solution_database) on success, or
        (None, None, None) when any resource fails to load.
    """
    try:
        tok = AutoTokenizer.from_pretrained(FINAL_MODEL_PATH)
        clf = AutoModelForSequenceClassification.from_pretrained(FINAL_MODEL_PATH)

        # Build the intent -> answer lookup from the CSV dataset.
        answers_df = pd.read_csv(SOLUTIONS_DATASET_PATH)
        intent_to_answer = answers_df.set_index('Intent')['Answer'].to_dict()

        print(" Resources loaded successfully!")
        return clf, tok, intent_to_answer
    except Exception as e:
        print(f"❌ Critical Error loading resources: {e}")
        return None, None, None
25
 
26
# Load everything once at startup; request handlers reuse these module globals.
model, tokenizer, solution_database = load_resources()

# --- Initialize FastAPI ---
app = FastAPI(title="Legal Aid API")
30
 
31
+ # --- API Data Models ---
32
class QueryRequest(BaseModel):
    # Incoming payload for POST /get-solution.
    question: str  # raw user question text fed to the tokenizer
34
 
35
class SolutionResponse(BaseModel):
    # Response schema for POST /get-solution.
    predicted_intent: str  # label from the classifier, or "Error" on failure
    solution: str  # answer text looked up from the CSV by intent
 
38
 
39
+ # --- API Endpoints ---
40
@app.post("/get-solution", response_model=SolutionResponse)
def get_legal_solution(request: QueryRequest):
    """Classify a legal question and return the matching canned answer.

    Args:
        request: QueryRequest carrying the raw user question.

    Returns:
        A dict matching SolutionResponse: the predicted intent label and
        the answer text looked up from the CSV database.
    """
    # Guard all three resources, not just the model: load_resources()
    # returns (None, None, None) on failure, and the original check let a
    # request fall through to `tokenizer(...)` / `solution_database.get(...)`
    # on None, crashing instead of returning the intended error payload.
    if model is None or tokenizer is None or solution_database is None:
        return {"predicted_intent": "Error", "solution": "Model not loaded."}

    inputs = tokenizer(request.question, return_tensors="pt", truncation=True, padding=True)
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    prediction_id = torch.argmax(logits, dim=1).item()
    predicted_intent = model.config.id2label[prediction_id]
    solution = solution_database.get(predicted_intent, "No solution found.")

    return {"predicted_intent": predicted_intent, "solution": solution}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
@app.get("/")
def read_root():
    """Simple liveness probe for the API root."""
    payload = {"status": "Legal Aid API is running."}
    return payload