Sp2503 committed on
Commit
91bccf1
·
verified ·
1 Parent(s): 92c1ad8

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +10 -10
main.py CHANGED
@@ -8,8 +8,8 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
  from typing import Dict, List, Optional
9
 
10
  # --- Configuration using Environment Variables from Hugging Face Secrets ---
11
- MODEL_NAME = os.getenv("MODEL_NAME")
12
- MONGO_URI = os.getenv("MONGO_URI")
13
  DB_NAME = "legal_aid-chatbot"
14
  COLLECTION_NAME = "categories"
15
 
@@ -41,7 +41,7 @@ app = FastAPI(title="Legal Aid Chatbot API", version="1.0.0")
41
  def startup_event():
42
  """Loads all necessary resources when the FastAPI application starts."""
43
  global model, tokenizer, intent_map, collection
44
-
45
  print("--- Loading resources on application startup ---")
46
 
47
  if not MONGO_URI:
@@ -50,7 +50,7 @@ def startup_event():
50
  if not MODEL_NAME:
51
  print("❌ Critical Error: MODEL_NAME secret is not set in Hugging Face Space settings.")
52
  return
53
-
54
  intent_map = create_intent_map(['womens_legal_questions_20k.csv', 'legal_aid_chatbot_dataset_20k.csv'])
55
  if not intent_map:
56
  print("❌ Could not create intent map. API will not function correctly.")
@@ -60,18 +60,18 @@ def startup_event():
60
  # Use a writable cache directory for Hugging Face Spaces
61
  cache_dir = "/tmp"
62
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
63
-
64
  # Explicitly provide the number of labels for the classifier
65
  num_labels = len(intent_map)
66
  model = AutoModelForSequenceClassification.from_pretrained(
67
- MODEL_NAME,
68
- num_labels=num_labels,
69
  cache_dir=cache_dir
70
  )
71
  print(f"✅ Model '{MODEL_NAME}' and tokenizer loaded successfully.")
72
  except Exception as e:
73
  print(f"❌ Critical Error loading Hugging Face model: {e}")
74
- model = None
75
 
76
  try:
77
  client = MongoClient(MONGO_URI)
@@ -100,12 +100,12 @@ def get_legal_solution(request: QueryRequest):
100
  inputs = tokenizer(request.question, return_tensors="pt", truncation=True, padding=True)
101
  with torch.no_grad():
102
  logits = model(**inputs).logits
103
-
104
  prediction_id = torch.argmax(logits, dim=1).item()
105
  predicted_intent = intent_map.get(prediction_id, "Unknown Intent")
106
 
107
  document = collection.find_one({"intent": predicted_intent})
108
-
109
  solution = document["answer"] if document and "answer" in document else "No specific solution was found for this topic."
110
 
111
  return SolutionResponse(predicted_intent=predicted_intent, solution=solution)
 
8
  from typing import Dict, List, Optional
9
 
10
  # --- Configuration using Environment Variables from Hugging Face Secrets ---
11
+ MODEL_NAME = os.getenv("MODEL_NAME")
12
+ MONGO_URI = os.getenv("MONGO_URI")
13
  DB_NAME = "legal_aid-chatbot"
14
  COLLECTION_NAME = "categories"
15
 
 
41
  def startup_event():
42
  """Loads all necessary resources when the FastAPI application starts."""
43
  global model, tokenizer, intent_map, collection
44
+
45
  print("--- Loading resources on application startup ---")
46
 
47
  if not MONGO_URI:
 
50
  if not MODEL_NAME:
51
  print("❌ Critical Error: MODEL_NAME secret is not set in Hugging Face Space settings.")
52
  return
53
+
54
  intent_map = create_intent_map(['womens_legal_questions_20k.csv', 'legal_aid_chatbot_dataset_20k.csv'])
55
  if not intent_map:
56
  print("❌ Could not create intent map. API will not function correctly.")
 
60
  # Use a writable cache directory for Hugging Face Spaces
61
  cache_dir = "/tmp"
62
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
63
+
64
  # Explicitly provide the number of labels for the classifier
65
  num_labels = len(intent_map)
66
  model = AutoModelForSequenceClassification.from_pretrained(
67
+ MODEL_NAME,
68
+ num_labels=num_labels,
69
  cache_dir=cache_dir
70
  )
71
  print(f"✅ Model '{MODEL_NAME}' and tokenizer loaded successfully.")
72
  except Exception as e:
73
  print(f"❌ Critical Error loading Hugging Face model: {e}")
74
+ model = None
75
 
76
  try:
77
  client = MongoClient(MONGO_URI)
 
100
  inputs = tokenizer(request.question, return_tensors="pt", truncation=True, padding=True)
101
  with torch.no_grad():
102
  logits = model(**inputs).logits
103
+
104
  prediction_id = torch.argmax(logits, dim=1).item()
105
  predicted_intent = intent_map.get(prediction_id, "Unknown Intent")
106
 
107
  document = collection.find_one({"intent": predicted_intent})
108
+
109
  solution = document["answer"] if document and "answer" in document else "No specific solution was found for this topic."
110
 
111
  return SolutionResponse(predicted_intent=predicted_intent, solution=solution)