Sp2503 committed on
Commit
4bea9cd
Β·
verified Β·
1 Parent(s): ccd41a5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +83 -87
main.py CHANGED
@@ -1,61 +1,49 @@
 
 
1
  import os
2
  import torch
3
  import pandas as pd
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
  from langdetect import detect
8
 
9
- # ============================================================
10
- # βœ… Environment setup β€” avoids permission errors on Hugging Face
11
- # ============================================================
 
 
 
12
  os.environ["HF_HOME"] = "/app/hf_cache"
13
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
14
- os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
15
- os.environ["TORCH_HOME"] = "/app/hf_cache"
16
-
17
- # ============================================================
18
- # βœ… Configuration
19
- # ============================================================
20
- FINAL_MODEL_PATH = './final_bert_model_pdf'
21
- SOLUTIONS_DATASET_PATH = 'qa_dataset_detailed_answers.csv'
22
- MURIL_MODEL_NAME = 'Sp2503/Muril-Model' # Your public HF model
23
 
24
- # ============================================================
25
- # βœ… Load models and data
26
- # ============================================================
27
  def load_resources():
28
  try:
29
  # Load English model
30
- tokenizer_en = AutoTokenizer.from_pretrained(FINAL_MODEL_PATH)
31
- model_en = AutoModelForSequenceClassification.from_pretrained(FINAL_MODEL_PATH)
32
-
33
- # Load MuRIL multilingual model (for non-English)
34
- tokenizer_muril = AutoTokenizer.from_pretrained(MURIL_MODEL_NAME)
35
- model_muril = AutoModelForSequenceClassification.from_pretrained(MURIL_MODEL_NAME)
36
-
37
- # Load Q&A dataset
38
- solutions_df = pd.read_csv(SOLUTIONS_DATASET_PATH)
39
- solution_database = solutions_df.set_index('Intent')['Answer'].to_dict()
40
-
41
- print("βœ… All models and data loaded successfully!")
42
- return model_en, tokenizer_en, model_muril, tokenizer_muril, solution_database
43
-
44
  except Exception as e:
45
  print(f"❌ Error loading models or data: {e}")
46
- return None, None, None, None, None
47
-
48
 
49
- model_en, tokenizer_en, model_muril, tokenizer_muril, solution_database = load_resources()
50
 
51
- # ============================================================
52
- # βœ… FastAPI app setup
53
- # ============================================================
54
- app = FastAPI(title="Legal Aid API")
55
 
56
- # ============================================================
57
- # βœ… Request and Response Models
58
- # ============================================================
59
  class QueryRequest(BaseModel):
60
  question: str
61
 
@@ -64,55 +52,63 @@ class SolutionResponse(BaseModel):
64
  solution: str
65
  model_used: str
66
 
67
- model_config = {
68
- "protected_namespaces": () # suppress Pydantic warning
69
- }
70
-
71
- # ============================================================
72
- # βœ… Helper: Detect if question is English
73
- # ============================================================
74
- def is_english(text: str) -> bool:
75
- try:
76
- lang = detect(text)
77
- return lang == "en"
78
- except:
79
- return True # default fallback to English
80
-
81
- # ============================================================
82
- # βœ… Main API Endpoint
83
- # ============================================================
84
  @app.post("/get-solution", response_model=SolutionResponse)
85
  def get_legal_solution(request: QueryRequest):
86
- if not model_en or not model_muril:
87
- return {"predicted_intent": "Error", "solution": "Models not loaded.", "model_used": "None"}
 
 
 
 
88
 
89
  question = request.question.strip()
90
- use_english = is_english(question)
91
-
92
- # Select model based on language
93
- model = model_en if use_english else model_muril
94
- tokenizer = tokenizer_en if use_english else tokenizer_muril
95
- model_name = "BERT-English" if use_english else "MuRIL-Multilingual"
96
-
97
- # Tokenize and predict
98
- inputs = tokenizer(question, return_tensors="pt", truncation=True, padding=True)
99
- with torch.no_grad():
100
- logits = model(**inputs).logits
101
- prediction_id = torch.argmax(logits, dim=1).item()
102
- predicted_intent = model.config.id2label.get(prediction_id, "Unknown")
103
-
104
- # Fetch solution
105
- solution = solution_database.get(predicted_intent, "No solution found in database.")
106
-
107
- return {
108
- "predicted_intent": predicted_intent,
109
- "solution": solution,
110
- "model_used": model_name
111
- }
112
-
113
- # ============================================================
114
- # βœ… Root Endpoint
115
- # ============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  @app.get("/")
117
- def read_root():
118
- return {"status": "βœ… Legal Aid API is running with English + MuRIL multilingual support."}
 
 
1
# main.py
#
# FastAPI service for a legal-aid chatbot: classifies a question's intent
# and returns a canned answer.  English questions use a local fine-tuned
# BERT model; everything else uses a multilingual MuRIL model from the Hub.

import os

# ========== CACHE SETUP ==========
# IMPORTANT: these must be set BEFORE importing transformers /
# huggingface_hub — the libraries resolve their default cache locations
# at import time, and the defaults are not writable on HF Spaces.
os.environ["HF_HOME"] = "/app/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
os.makedirs("/app/hf_cache", exist_ok=True)

import torch
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from langdetect import detect

# ========== CONFIG ==========
FINAL_MODEL_PATH = './final_bert_model_pdf'                 # Local fine-tuned English model
MURIL_MODEL_ID = 'Sp2503/Muril-Model'                       # Hugging Face multilingual model
SOLUTIONS_DATASET_PATH = 'qa_dataset_detailed_answers.csv'  # Intent -> Answer table

# ========== LOAD MODELS ==========
 
 
22
def load_resources():
    """Load the English BERT model, the MuRIL pipeline, and the answer DB.

    Returns:
        ``(model, tokenizer, muril_pipeline, solution_db)`` on success, or
        ``(None, None, None, None)`` if anything fails — deliberately
        best-effort so the API process still starts and the endpoints can
        report the problem instead of the Space crash-looping.
    """
    try:
        # Local fine-tuned English intent classifier.
        tokenizer = AutoTokenizer.from_pretrained(FINAL_MODEL_PATH)
        model = AutoModelForSequenceClassification.from_pretrained(FINAL_MODEL_PATH)

        # Multilingual MuRIL classifier (pulled from the Hub) used for
        # non-English questions.
        muril_pipeline = pipeline("text-classification", model=MURIL_MODEL_ID)

        # Intent -> detailed answer lookup table.
        df = pd.read_csv(SOLUTIONS_DATASET_PATH)
        solution_db = df.set_index('Intent')['Answer'].to_dict()

        print("βœ… All models & data loaded successfully!")
        return model, tokenizer, muril_pipeline, solution_db

    except Exception as e:
        # Keep the process alive, but log the full traceback — a bare
        # f"{e}" hides WHERE startup failed (model dir? Hub download? CSV?).
        import traceback
        traceback.print_exc()
        print(f"❌ Error loading models or data: {e}")
        return None, None, None, None
 
40
 
41
# Load everything once at import/startup time.  Any of these may be None
# if load_resources() failed; the endpoints check before use.
model, tokenizer, muril_pipeline, solution_db = load_resources()

# ========== FASTAPI APP ==========
app = FastAPI(title="AI LegalAid Chatbot Server")

# Request / Response Schemas
 
 
47
class QueryRequest(BaseModel):
    """Request body for POST /get-solution."""

    # The user's legal question, in any language (routing is decided
    # by language detection in the endpoint).
    question: str
49
 
 
52
  solution: str
53
  model_used: str
54
 
55
+ # ========== LOGIC ==========
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
@app.post("/get-solution", response_model=SolutionResponse)
def get_legal_solution(request: QueryRequest):
    """Classify the question's intent and return the matching answer.

    Routing: non-English questions (per langdetect) go to the MuRIL
    pipeline; English ones go to the local fine-tuned BERT model.
    Failures never raise to the client — they are reported inside the
    response body with ``predicted_intent == "Error"``.
    """
    if not model or not tokenizer:
        # Startup load failed (load_resources returned Nones).
        return {
            "predicted_intent": "Error",
            "solution": "Model not loaded properly.",
            "model_used": "None"
        }

    question = request.question.strip()

    # Detect language.  langdetect raises (LangDetectException) on empty
    # or undetectable text; default to English in that case.
    # NOTE: was a bare `except:` — narrowed so SystemExit/KeyboardInterrupt
    # are never swallowed.
    try:
        lang = detect(question)
    except Exception:
        lang = "en"

    # Non-English -> multilingual MuRIL pipeline.
    if lang != "en":
        try:
            muril_result = muril_pipeline(question)
            predicted_intent = muril_result[0]['label']
            solution = solution_db.get(predicted_intent, "No solution found for this intent.")
            return {
                "predicted_intent": predicted_intent,
                "solution": solution,
                "model_used": "MuRIL"
            }
        except Exception as e:
            return {
                "predicted_intent": "Error",
                "solution": f"MuRIL model failed: {e}",
                "model_used": "MuRIL"
            }

    # English -> fine-tuned BERT classifier.
    try:
        inputs = tokenizer(question, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        prediction_id = torch.argmax(logits, dim=1).item()
        predicted_intent = model.config.id2label[prediction_id]
        solution = solution_db.get(predicted_intent, "No solution found for this intent.")
        return {
            "predicted_intent": predicted_intent,
            "solution": solution,
            "model_used": "English BERT"
        }
    except Exception as e:
        return {
            "predicted_intent": "Error",
            "solution": f"English model failed: {e}",
            "model_used": "English BERT"
        }
110
+
111
@app.get("/")
def root():
    """Health check: reports whether all models loaded at startup."""
    # Same truthiness semantics as all([...]), spelled out explicitly.
    models_ready = bool(model) and bool(tokenizer) and bool(muril_pipeline)
    return {"status": "βœ… AI LegalAid Chatbot Running", "models_ready": models_ready}