Sp2503 committed on
Commit
cb5b1d2
Β·
verified Β·
1 Parent(s): ddffd31

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +37 -82
main.py CHANGED
@@ -1,102 +1,57 @@
 
 
1
  import os
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
- from pymongo import MongoClient
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
- from sentence_transformers import SentenceTransformer
8
- from typing import List, Optional
9
 
10
  # --- Configuration ---
11
- MONGO_URI = os.getenv("MONGO_URI", "mongodb+srv://saisunil22ecs:9m2ajd0GxVn43Fbu@majorproject.g0g1as0.mongodb.net/?retryWrites=true&w=majority&appName=MajorProject")
12
- DB_NAME = os.getenv("MONGO_DB", "legal_chatbot_db")
13
- COLLECTION_NAME = os.getenv("MONGO_COLLECTION", "datasets")
14
- MODEL_PATH = os.getenv("MODEL_PATH", "./final_bert_model_pdf")
15
- EMBED_MODEL = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
16
 
17
- # --- Resource Loading ---
18
  def load_resources():
19
  try:
20
- print("πŸ”„ Loading intent classification model...")
21
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
22
- intent_model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
23
- print("βœ… Intent model loaded.")
24
-
25
- print("πŸ”„ Loading embedding model for vector search...")
26
-
27
- # --- THIS IS THE FIX ---
28
- # We specify a local cache directory where the app has write permissions.
29
- cache_dir = "./model_cache"
30
- embedding_model = SentenceTransformer(EMBED_MODEL, cache_folder=cache_dir)
31
- # ---
32
-
33
- print("βœ… Embedding model loaded.")
34
-
35
- print("πŸ”„ Connecting to MongoDB Atlas...")
36
- client = MongoClient(MONGO_URI)
37
- db = client[DB_NAME]
38
- collection = db[COLLECTION_NAME]
39
- db.command('ping') # Verify connection
40
- print("βœ… MongoDB connection successful!")
41
-
42
- return tokenizer, intent_model, embedding_model, collection
43
  except Exception as e:
44
- print(f"❌ Critical Error during startup: {e}")
45
- return None, None, None, None
46
 
47
- tokenizer, intent_model, embedding_model, collection = load_resources()
48
- # --- FastAPI App ---
49
- app = FastAPI(title="Legal Aid Chatbot API")
 
50
 
51
  # --- API Data Models ---
52
- class ChatRequest(BaseModel):
53
- query: str
54
- top_k: Optional[int] = 3
55
 
56
- class ChatResponse(BaseModel):
57
- intent: str
58
- answers: List[dict]
59
 
60
  # --- API Endpoints ---
61
- @app.get("/")
62
- def health_check():
63
- """A simple endpoint to check if the API is running."""
64
- return {"status": "ok", "resources_loaded": all([tokenizer, intent_model, embedding_model, collection])}
65
-
66
- @app.post("/chat", response_model=ChatResponse)
67
- def chat(req: ChatRequest):
68
- """
69
- Main chat endpoint that performs intent classification and vector search.
70
- """
71
- if not all([tokenizer, intent_model, embedding_model, collection]):
72
- return {"intent": "Error", "answers": [{"answer": "Server is not ready. Resources could not be loaded."}]}
73
 
74
- # Step 1: Intent Classification
75
- inputs = tokenizer(req.query, return_tensors="pt", truncation=True)
76
  with torch.no_grad():
77
- logits = intent_model(**inputs).logits
78
- pred_id = torch.argmax(logits, dim=1).item()
79
- intent = intent_model.config.id2label[pred_id]
80
-
81
- # Step 2: Vector Search in MongoDB to find the most relevant documents
82
- query_embedding = embedding_model.encode(req.query, normalize_embeddings=True).tolist()
83
- pipeline = [
84
- {
85
- '$vectorSearch': {
86
- 'index': 'kb_vector_index', # Ensure this index name matches your MongoDB Atlas index
87
- 'path': 'embedding',
88
- 'queryVector': query_embedding,
89
- 'numCandidates': 100,
90
- 'limit': req.top_k
91
- }
92
- },
93
- {'$project': {'_id': 0, 'answer': '$Answer', 'question': '$Question', 'intent': '$Intent', 'score': {'$meta': 'vectorSearchScore'}}}
94
- ]
95
-
96
- try:
97
- results = list(collection.aggregate(pipeline))
98
- except Exception as e:
99
- print(f"Error during vector search: {e}")
100
- return {"intent": intent, "answers": [{"answer": "Could not retrieve documents from the knowledge base."}]}
101
 
102
- return {"intent": intent, "answers": results}
 
 
 
 
 
1
+ # main.py
2
+
3
  import os
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
 
6
  import torch
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+ import pandas as pd
 
9
 
10
  # --- Configuration ---
11
+ FINAL_MODEL_PATH = './final_bert_model_pdf'
12
+ SOLUTIONS_DATASET_PATH = 'qa_dataset_detailed_answers.csv'
 
 
 
13
 
14
+ # --- Load Models and Data ---
15
  def load_resources():
16
  try:
17
+ tokenizer = AutoTokenizer.from_pretrained(FINAL_MODEL_PATH)
18
+ model = AutoModelForSequenceClassification.from_pretrained(FINAL_MODEL_PATH)
19
+ solutions_df = pd.read_csv(SOLUTIONS_DATASET_PATH)
20
+ solution_database = solutions_df.set_index('Intent')['Answer'].to_dict()
21
+ print("βœ… Resources loaded successfully!")
22
+ return model, tokenizer, solution_database
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  except Exception as e:
24
+ print(f"❌ Critical Error loading resources: {e}")
25
+ return None, None, None
26
 
27
+ model, tokenizer, solution_database = load_resources()
28
+
29
+ # --- Initialize FastAPI ---
30
+ app = FastAPI(title="Legal Aid API")
31
 
32
  # --- API Data Models ---
33
+ class QueryRequest(BaseModel):
34
+ question: str
 
35
 
36
+ class SolutionResponse(BaseModel):
37
+ predicted_intent: str
38
+ solution: str
39
 
40
  # --- API Endpoints ---
41
+ @app.post("/get-solution", response_model=SolutionResponse)
42
+ def get_legal_solution(request: QueryRequest):
43
+ if not model:
44
+ return {"predicted_intent": "Error", "solution": "Model not loaded."}
 
 
 
 
 
 
 
 
45
 
46
+ inputs = tokenizer(request.question, return_tensors="pt", truncation=True, padding=True)
 
47
  with torch.no_grad():
48
+ logits = model(**inputs).logits
49
+ prediction_id = torch.argmax(logits, dim=1).item()
50
+ predicted_intent = model.config.id2label[prediction_id]
51
+ solution = solution_database.get(predicted_intent, "No solution found.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ return {"predicted_intent": predicted_intent, "solution": solution}
54
+
55
+ @app.get("/")
56
+ def read_root():
57
+ return {"status": "Legal Aid API is running."}