roshcheeku committed on
Commit
6e691b7
·
verified ·
1 Parent(s): dbd15ac

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +8 -3
model_utils.py CHANGED
@@ -1,18 +1,23 @@
1
  import os
2
- os.environ["TRANSFORMERS_CACHE"] = "./cache"
3
- os.environ["HF_HOME"] = "./cache" # also set HF_HOME to cache in same dir
 
 
4
 
5
  from transformers import pipeline
6
 
 
7
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
8
 
 
9
  labels = ["question", "option", "answer", "other"]
10
 
11
def classify_chunk(text):
    """Zero-shot classify *text* and return the single most likely label.

    Relies on the module-level ``classifier`` pipeline and candidate
    ``labels``; the pipeline returns candidates ranked by score, so the
    first entry is the top prediction.
    """
    prediction = classifier(text, labels)
    top_label = prediction['labels'][0]
    return top_label
14
 
15
  def extract_mcqs_with_model(text):
 
16
  chunks = [chunk.strip() for chunk in text.split("\n\n") if chunk.strip()]
17
  mcqs = []
18
  current = {"question": "", "options": [], "answer": ""}
 
1
  import os
2
+
3
+ # Fix: Redirect Hugging Face cache to a writable folder
4
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
5
+ os.environ["HF_HOME"] = "/tmp/hf_cache" # new standard from transformers v5+
6
 
7
  from transformers import pipeline
8
 
9
+ # Initialize zero-shot classification pipeline
10
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
11
 
12
+ # Define labels for classification
13
  labels = ["question", "option", "answer", "other"]
14
 
15
def classify_chunk(text):
    """Return the best-matching label for *text* via zero-shot classification.

    Uses the module-level ``classifier`` pipeline with the candidate
    ``labels``. The result's 'labels' list is sorted by descending score,
    so index 0 holds the top prediction.
    """
    ranked = classifier(text, labels)['labels']
    return ranked[0]
18
 
19
  def extract_mcqs_with_model(text):
20
+ # Split text into chunks, skipping empty ones
21
  chunks = [chunk.strip() for chunk in text.split("\n\n") if chunk.strip()]
22
  mcqs = []
23
  current = {"question": "", "options": [], "answer": ""}