roshcheeku committed on
Commit
8385a71
·
verified ·
1 Parent(s): 00c7c1d

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +26 -19
model_utils.py CHANGED
@@ -2,38 +2,45 @@ import os
2
 
3
  # Fix: Redirect Hugging Face cache to a writable folder
4
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
5
- os.environ["HF_HOME"] = "/tmp/hf_cache" # new standard from transformers v5+
6
 
7
  from transformers import pipeline
8
 
9
- # Initialize zero-shot classification pipeline
10
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
11
 
12
- # Define labels for classification
13
  labels = ["question", "option", "answer", "other"]
14
 
15
- def classify_chunk(text):
16
- result = classifier(text, labels)
17
- return result['labels'][0] # Return the top predicted label
18
-
19
  def extract_mcqs_with_model(text):
20
- # Split text into chunks, skipping empty ones
 
 
 
 
21
  chunks = [chunk.strip() for chunk in text.split("\n\n") if chunk.strip()]
22
  mcqs = []
23
  current = {"question": "", "options": [], "answer": ""}
24
 
25
- for chunk in chunks:
26
- label = classify_chunk(chunk)
27
- if label == "question":
28
- if current["question"]:
29
- mcqs.append(current)
30
- current = {"question": "", "options": [], "answer": ""}
31
- current["question"] = chunk
32
- elif label == "option":
33
- current["options"].append(chunk)
34
- elif label == "answer":
35
- current["answer"] = chunk
 
 
 
 
 
 
36
 
37
  if current["question"]:
38
  mcqs.append(current)
 
39
  return mcqs
 
2
 
3
  # Fix: Redirect Hugging Face cache to a writable folder
4
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
5
+ os.environ["HF_HOME"] = "/tmp/hf_cache" # New standard from transformers v5+
6
 
7
  from transformers import pipeline
8
 
9
+ # Load the model once at the start — this avoids reloading on every request
10
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
11
 
12
+ # Define the classification labels
13
  labels = ["question", "option", "answer", "other"]
14
 
 
 
 
 
15
  def extract_mcqs_with_model(text):
16
+ """
17
+ Extract MCQs from a given large body of text using zero-shot classification.
18
+ Optimized for large documents by batch processing.
19
+ """
20
+ # Clean and split text into meaningful chunks
21
  chunks = [chunk.strip() for chunk in text.split("\n\n") if chunk.strip()]
22
  mcqs = []
23
  current = {"question": "", "options": [], "answer": ""}
24
 
25
+ # Process chunks in batches for speed (e.g., 5 chunks at a time)
26
+ batch_size = 10
27
+ for i in range(0, len(chunks), batch_size):
28
+ batch = chunks[i:i+batch_size]
29
+ results = classifier(batch, labels)
30
+
31
+ for chunk, result in zip(batch, results):
32
+ label = result['labels'][0]
33
+ if label == "question":
34
+ if current["question"]:
35
+ mcqs.append(current)
36
+ current = {"question": "", "options": [], "answer": ""}
37
+ current["question"] = chunk
38
+ elif label == "option":
39
+ current["options"].append(chunk)
40
+ elif label == "answer":
41
+ current["answer"] = chunk
42
 
43
  if current["question"]:
44
  mcqs.append(current)
45
+
46
  return mcqs