heerjtdev commited on
Commit
32a28cd
·
verified ·
1 Parent(s): 1b4ce30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -18
app.py CHANGED
@@ -5,11 +5,27 @@ from sentence_transformers import CrossEncoder
5
  import re
6
  import hashlib
7
  import json
 
 
 
 
 
 
8
 
9
  # ============================================================
10
  # MODEL LOADING (ONCE)
11
  # ============================================================
12
 
 
 
 
 
 
 
 
 
 
 
13
  DEVICE = "cpu"
14
 
15
  SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
@@ -64,27 +80,62 @@ def classify_question(question):
64
  # SCHEMA GENERATION (AUTO, NO LLM)
65
  # ============================================================
66
 
67
def generate_schema(kb, question):
    """
    Build a grading schema directly from the knowledge base.

    Picks the single KB sentence most relevant to *question* (scored by
    the cross-encoder similarity model) as the required concept.
    Fully deterministic — no LLM call involved.
    """
    kb_sentences = split_sentences(kb)

    # Score every (sentence, question) pair and keep the top sentence.
    pairs = [(sentence, question) for sentence in kb_sentences]
    relevance = sim_model.predict(pairs)
    anchor_sentence = kb_sentences[int(relevance.argmax())]

    return {
        "question_type": classify_question(question),
        "required_concepts": [anchor_sentence],
        "forbidden_concepts": [],
        "allow_extra_info": True,
    }
87
 
 
88
  # ============================================================
89
  # ANSWER DECOMPOSITION
90
  # ============================================================
@@ -105,7 +156,7 @@ def evaluate_answer(answer, question, kb):
105
  # --------------------
106
  key = hash_key(kb, question)
107
  if key not in SCHEMA_CACHE:
108
- SCHEMA_CACHE[key] = generate_schema(kb, question)
109
 
110
  schema = SCHEMA_CACHE[key]
111
  logs["schema"] = schema
 
5
  import re
6
  import hashlib
7
  import json
8
+ import os
9
+ from openai import OpenAI
10
+
11
+
12
+
13
+
14
 
15
  # ============================================================
16
  # MODEL LOADING (ONCE)
17
  # ============================================================
18
 
19
# Fail fast at import time if the OpenAI credential is absent, so the
# app never reaches request handling with an unusable client.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY not found in environment")

# Single shared client instance reused by all schema-generation calls.
llm_client = OpenAI(api_key=OPENAI_API_KEY)
24
+
25
+
26
+
27
+
28
+
29
  DEVICE = "cpu"
30
 
31
  SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
 
80
  # SCHEMA GENERATION (LLM-BASED, CACHED)
81
  # ============================================================
82
 
83
def generate_schema_with_llm(kb, question):
    """
    Generate an explicit grading schema for *question* from *kb* via the
    OpenAI chat API.

    Parameters:
        kb (str): knowledge-base text the answer key is derived from.
        question (str): the exam question being graded.

    Returns:
        dict: schema with keys "question_type", "required_concepts",
        "forbidden_concepts", "allow_extra_info".

    Raises:
        ValueError: if the model response cannot be parsed as JSON.

    NOTE(review): temperature=0 makes the output highly stable but is not
    a hard determinism guarantee; the caller caches results per (kb,
    question) key, which is what actually pins the schema.
    """
    prompt = f"""
You are an exam answer key generator.

Knowledge Base:
\"\"\"
{kb}
\"\"\"

Question:
\"\"\"
{question}
\"\"\"

TASK:
Extract the expected answer as atomic facts.
Return STRICT JSON with this schema:

{{
  "question_type": "FACT | DEFINITION | EXPLANATION",
  "required_concepts": ["fact1", "fact2"],
  "forbidden_concepts": [],
  "allow_extra_info": true
}}

Rules:
- required_concepts must be explicit factual statements
- Do NOT paraphrase excessively
- Do NOT invent facts
- JSON only. No explanations.
"""

    response = llm_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You generate grading rubrics for exams."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
        # Ask the API for a bare JSON object; without this the model
        # frequently wraps output in ```json fences, which would make
        # json.loads() below raise.
        response_format={"type": "json_object"},
    )

    content = response.choices[0].message.content.strip()

    # Defensive: strip Markdown code fences if the model added them anyway.
    if content.startswith("```"):
        content = content.strip("`").strip()
        if content.lower().startswith("json"):
            content = content[4:].lstrip()

    try:
        schema = json.loads(content)
    except json.JSONDecodeError:
        raise ValueError(f"LLM returned invalid JSON:\n{content}")

    # Backfill any key the model omitted so downstream grading code can
    # index into the schema without KeyErrors.
    schema.setdefault("question_type", "FACT")
    schema.setdefault("required_concepts", [])
    schema.setdefault("forbidden_concepts", [])
    schema.setdefault("allow_extra_info", True)

    return schema
137
 
138
+
139
  # ============================================================
140
  # ANSWER DECOMPOSITION
141
  # ============================================================
 
156
  # --------------------
157
  key = hash_key(kb, question)
158
  if key not in SCHEMA_CACHE:
159
+ SCHEMA_CACHE[key] = generate_schema_with_llm(kb, question)
160
 
161
  schema = SCHEMA_CACHE[key]
162
  logs["schema"] = schema