Sp2503 committed on
Commit
c1dfe4e
Β·
verified Β·
1 Parent(s): 91bccf1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +15 -25
main.py CHANGED
@@ -2,24 +2,26 @@ import os
2
  import pandas as pd
3
  import torch
4
  from fastapi import FastAPI, HTTPException
5
- from pydantic import BaseModel
6
  from pymongo import MongoClient
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
  from typing import Dict, List, Optional
9
 
10
- # --- Configuration using Environment Variables from Hugging Face Secrets ---
11
- MODEL_NAME = os.getenv("MODEL_NAME")
12
- MONGO_URI = os.getenv("MONGO_URI")
 
 
 
13
  DB_NAME = "legal_aid-chatbot"
14
  COLLECTION_NAME = "categories"
15
 
16
- # --- Global Resources (loaded once at startup) ---
17
  model: Optional[AutoModelForSequenceClassification] = None
18
  tokenizer: Optional[AutoTokenizer] = None
19
  intent_map: Dict[int, str] = {}
20
  collection: Optional[MongoClient] = None
21
 
22
- # --- Helper function to create the intent map ---
23
  def create_intent_map(csv_files: List[str]) -> Dict[int, str]:
24
  """Creates a consistent intent-to-ID mapping from the training CSVs."""
25
  all_intents = set()
@@ -31,58 +33,49 @@ def create_intent_map(csv_files: List[str]) -> Dict[int, str]:
31
  except FileNotFoundError as e:
32
  print(f"❌ Critical Error: CSV for intent mapping not found: {e}")
33
  return {}
34
- # Sort the intents to ensure the mapping is always the same
35
  return {i: intent for i, intent in enumerate(sorted(list(all_intents)))}
36
 
37
- # --- Application Startup Event ---
38
  app = FastAPI(title="Legal Aid Chatbot API", version="1.0.0")
39
 
40
  @app.on_event("startup")
41
  def startup_event():
42
  """Loads all necessary resources when the FastAPI application starts."""
43
  global model, tokenizer, intent_map, collection
44
-
45
  print("--- Loading resources on application startup ---")
46
 
47
  if not MONGO_URI:
48
  print("❌ Critical Error: MONGO_URI secret is not set in Hugging Face Space settings.")
49
  return
50
- if not MODEL_NAME:
51
- print("❌ Critical Error: MODEL_NAME secret is not set in Hugging Face Space settings.")
52
- return
53
-
54
  intent_map = create_intent_map(['womens_legal_questions_20k.csv', 'legal_aid_chatbot_dataset_20k.csv'])
55
  if not intent_map:
56
  print("❌ Could not create intent map. API will not function correctly.")
57
  return
58
 
59
  try:
60
- # Use a writable cache directory for Hugging Face Spaces
61
  cache_dir = "/tmp"
62
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
63
-
64
- # Explicitly provide the number of labels for the classifier
65
  num_labels = len(intent_map)
66
  model = AutoModelForSequenceClassification.from_pretrained(
67
- MODEL_NAME,
68
- num_labels=num_labels,
69
  cache_dir=cache_dir
70
  )
71
  print(f"βœ… Model '{MODEL_NAME}' and tokenizer loaded successfully.")
72
  except Exception as e:
73
  print(f"❌ Critical Error loading Hugging Face model: {e}")
74
- model = None
75
 
76
  try:
77
  client = MongoClient(MONGO_URI)
78
  collection = client[DB_NAME][COLLECTION_NAME]
79
- client.server_info() # Test connection
80
  print("πŸš€ Successfully connected to MongoDB.")
81
  except Exception as e:
82
  print(f"❌ Critical Error connecting to MongoDB: {e}")
83
  collection = None
84
 
85
- # --- API Data Models ---
86
  class QueryRequest(BaseModel):
87
  question: str
88
 
@@ -90,22 +83,19 @@ class SolutionResponse(BaseModel):
90
  predicted_intent: str
91
  solution: str
92
 
93
- # --- API Endpoint ---
94
  @app.post("/get-solution", response_model=SolutionResponse)
95
  def get_legal_solution(request: QueryRequest):
96
- """Receives a question, predicts intent, and retrieves the solution from MongoDB."""
97
  if not all([model, tokenizer, collection]):
98
  raise HTTPException(status_code=503, detail="Server resources are not ready. Check startup logs for errors.")
99
 
100
  inputs = tokenizer(request.question, return_tensors="pt", truncation=True, padding=True)
101
  with torch.no_grad():
102
  logits = model(**inputs).logits
103
-
104
  prediction_id = torch.argmax(logits, dim=1).item()
105
  predicted_intent = intent_map.get(prediction_id, "Unknown Intent")
106
 
107
  document = collection.find_one({"intent": predicted_intent})
108
-
109
  solution = document["answer"] if document and "answer" in document else "No specific solution was found for this topic."
110
 
111
  return SolutionResponse(predicted_intent=predicted_intent, solution=solution)
 
2
import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
# FIX: the module is "pydantic" — the committed "pantic" raises ImportError
# the moment this file is imported, so the whole API fails to start.
from pydantic import BaseModel
from pymongo import MongoClient
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Dict, List, Optional

# --- Configuration ---
# Hardcoding the correct model name from the URL to bypass any issues with secrets.
MODEL_NAME = "Sp2503/Bertmodel"

# The MongoDB URI is loaded from Hugging Face Space secrets for security.
MONGO_URI = os.getenv("MONGO_URI")
DB_NAME = "legal_aid-chatbot"
COLLECTION_NAME = "categories"

# --- Global Resources (populated once by the startup event) ---
model: Optional[AutoModelForSequenceClassification] = None
tokenizer: Optional[AutoTokenizer] = None
intent_map: Dict[int, str] = {}
collection: Optional[MongoClient] = None
24
 
 
25
  def create_intent_map(csv_files: List[str]) -> Dict[int, str]:
26
  """Creates a consistent intent-to-ID mapping from the training CSVs."""
27
  all_intents = set()
 
33
  except FileNotFoundError as e:
34
  print(f"❌ Critical Error: CSV for intent mapping not found: {e}")
35
  return {}
 
36
  return {i: intent for i, intent in enumerate(sorted(list(all_intents)))}
37
 
 
38
app = FastAPI(title="Legal Aid Chatbot API", version="1.0.0")

# NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
# lifespan handlers — confirm the pinned FastAPI version before migrating.
@app.on_event("startup")
def startup_event():
    """Loads all necessary resources when the FastAPI application starts."""
    global model, tokenizer, intent_map, collection

    print("--- Loading resources on application startup ---")

    # Without the Mongo URI secret nothing else is useful; bail out early.
    if not MONGO_URI:
        print("❌ Critical Error: MONGO_URI secret is not set in Hugging Face Space settings.")
        return

    # Build the label-id -> intent-name mapping from the training CSVs.
    intent_map = create_intent_map(['womens_legal_questions_20k.csv', 'legal_aid_chatbot_dataset_20k.csv'])
    if not intent_map:
        print("❌ Could not create intent map. API will not function correctly.")
        return

    # Load tokenizer and classifier; /tmp is the writable cache on Spaces.
    try:
        hf_cache = "/tmp"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=hf_cache)
        model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME,
            num_labels=len(intent_map),
            cache_dir=hf_cache
        )
        print(f"βœ… Model '{MODEL_NAME}' and tokenizer loaded successfully.")
    except Exception as exc:
        print(f"❌ Critical Error loading Hugging Face model: {exc}")
        model = None

    # Connect to MongoDB; server_info() forces a round-trip to verify it.
    try:
        mongo_client = MongoClient(MONGO_URI)
        collection = mongo_client[DB_NAME][COLLECTION_NAME]
        mongo_client.server_info()
        print("πŸš€ Successfully connected to MongoDB.")
    except Exception as exc:
        print(f"❌ Critical Error connecting to MongoDB: {exc}")
        collection = None
78
 
 
79
class QueryRequest(BaseModel):
    """Request payload for /get-solution: the user's legal question."""

    question: str
81
 
 
83
  predicted_intent: str
84
  solution: str
85
 
 
86
@app.post("/get-solution", response_model=SolutionResponse)
def get_legal_solution(request: QueryRequest):
    """Predict the question's intent and return the matching solution from MongoDB."""
    # All three resources must have loaded during startup; otherwise refuse.
    if not all([model, tokenizer, collection]):
        raise HTTPException(status_code=503, detail="Server resources are not ready. Check startup logs for errors.")

    # Classify the question with the fine-tuned sequence classifier.
    encoded = tokenizer(request.question, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**encoded).logits

    intent_id = torch.argmax(logits, dim=1).item()
    intent_label = intent_map.get(intent_id, "Unknown Intent")

    # Look up the canned answer for the predicted intent.
    record = collection.find_one({"intent": intent_label})
    if record and "answer" in record:
        answer_text = record["answer"]
    else:
        answer_text = "No specific solution was found for this topic."

    return SolutionResponse(predicted_intent=intent_label, solution=answer_text)