Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,11 +5,27 @@ from sentence_transformers import CrossEncoder
|
|
| 5 |
import re
|
| 6 |
import hashlib
|
| 7 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# ============================================================
|
| 10 |
# MODEL LOADING (ONCE)
|
| 11 |
# ============================================================
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
DEVICE = "cpu"
|
| 14 |
|
| 15 |
SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
|
|
@@ -64,27 +80,62 @@ def classify_question(question):
|
|
| 64 |
# SCHEMA GENERATION (AUTO, NO LLM)
|
| 65 |
# ============================================================
|
| 66 |
|
| 67 |
-
def
|
| 68 |
"""
|
| 69 |
-
|
| 70 |
-
Deterministic
|
| 71 |
"""
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
return schema
|
| 87 |
|
|
|
|
| 88 |
# ============================================================
|
| 89 |
# ANSWER DECOMPOSITION
|
| 90 |
# ============================================================
|
|
@@ -105,7 +156,7 @@ def evaluate_answer(answer, question, kb):
|
|
| 105 |
# --------------------
|
| 106 |
key = hash_key(kb, question)
|
| 107 |
if key not in SCHEMA_CACHE:
|
| 108 |
-
SCHEMA_CACHE[key] =
|
| 109 |
|
| 110 |
schema = SCHEMA_CACHE[key]
|
| 111 |
logs["schema"] = schema
|
|
|
|
| 5 |
import re
|
| 6 |
import hashlib
|
| 7 |
import json
|
| 8 |
+
import os
|
| 9 |
+
from openai import OpenAI
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
|
| 15 |
# ============================================================
|
| 16 |
# MODEL LOADING (ONCE)
|
| 17 |
# ============================================================
|
| 18 |
|
| 19 |
+
# --- OpenAI client configuration (runs once at import time) ---
# The API key must be provided via the environment (on HF Spaces: a secret).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    # Fail fast: without a key every schema-generation call would fail later.
    # NOTE(review): this raise happens at module import, so a missing secret
    # makes the whole Space crash with an opaque "Runtime error" banner —
    # consider deferring client creation until first use; confirm with owner.
    raise RuntimeError("OPENAI_API_KEY not found in environment")

# Shared client used by generate_schema_with_llm(); created once per process.
llm_client = OpenAI(api_key=OPENAI_API_KEY)

# Similarity model runs on CPU (Spaces free tier has no GPU).
DEVICE = "cpu"

# Cross-encoder used for answer/concept similarity scoring.
SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
|
|
|
|
| 80 |
# SCHEMA GENERATION (VIA OPENAI LLM, CACHED)
|
| 81 |
# ============================================================
|
| 82 |
|
| 83 |
+
def generate_schema_with_llm(kb, question):
    """Generate an explicit grading schema for *question* using ChatGPT.

    Calls gpt-4o-mini with temperature=0 (deterministic for caching by the
    caller, which stores results in SCHEMA_CACHE keyed on (kb, question)).

    Args:
        kb: Knowledge-base text the expected answer must come from.
        question: The exam question to build an answer key for.

    Returns:
        dict parsed from the model's JSON output, with keys
        "question_type", "required_concepts", "forbidden_concepts",
        "allow_extra_info".

    Raises:
        ValueError: if the model's reply is not valid JSON.
    """

    prompt = f"""
You are an exam answer key generator.

Knowledge Base:
\"\"\"
{kb}
\"\"\"

Question:
\"\"\"
{question}
\"\"\"

TASK:
Extract the expected answer as atomic facts.
Return STRICT JSON with this schema:

{{
"question_type": "FACT | DEFINITION | EXPLANATION",
"required_concepts": ["fact1", "fact2"],
"forbidden_concepts": [],
"allow_extra_info": true
}}

Rules:
- required_concepts must be explicit factual statements
- Do NOT paraphrase excessively
- Do NOT invent facts
- JSON only. No explanations.
"""

    response = llm_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You generate grading rubrics for exams."},
            {"role": "user", "content": prompt}
        ],
        temperature=0,
        # JSON mode: the API guarantees the reply is a valid JSON object,
        # eliminating most parse failures below.
        response_format={"type": "json_object"}
    )

    content = response.choices[0].message.content.strip()

    # Defensive: models sometimes wrap JSON in markdown fences (```json ... ```)
    # despite "JSON only" instructions — strip them before parsing.
    content = re.sub(r"^```(?:json)?\s*|\s*```$", "", content).strip()

    try:
        schema = json.loads(content)
    except json.JSONDecodeError as err:
        # Chain the original decode error so the failing offset is preserved.
        raise ValueError(f"LLM returned invalid JSON:\n{content}") from err

    return schema
|
| 137 |
|
| 138 |
+
|
| 139 |
# ============================================================
|
| 140 |
# ANSWER DECOMPOSITION
|
| 141 |
# ============================================================
|
|
|
|
| 156 |
# --------------------
|
| 157 |
key = hash_key(kb, question)
|
| 158 |
if key not in SCHEMA_CACHE:
|
| 159 |
+
SCHEMA_CACHE[key] = generate_schema_with_llm(kb, question)
|
| 160 |
|
| 161 |
schema = SCHEMA_CACHE[key]
|
| 162 |
logs["schema"] = schema
|