import os
import json
from typing import Any, Dict
from flask import Flask, request, jsonify
from openai import OpenAI
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
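# Assumed runtime dependencies (inferred from the imports above; versions are a
# guess, not pinned by this file): flask, a recent openai SDK with the Responses
# API, langchain-community, chromadb, and sentence-transformers.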
# ----------------------------
# Configuration
# ----------------------------
CHROMA_DIR = os.getenv("CHROMA_DIR", "./chroma_db")
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "thenlper/gte-large")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4.1-nano")
# Hugging Face Spaces: put your key in "Settings -> Secrets" as OPENAI_API_KEY
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError(
        "Missing OPENAI_API_KEY environment variable. "
        "In Hugging Face Spaces, add it in Settings -> Secrets."
    )
client = OpenAI(api_key=OPENAI_API_KEY)
# ----------------------------
# Prompt + JSON schema (from notebook)
# ----------------------------
DEV_PROMPT = """
You are a Texas Grade 5 Mathematics tutor for kids, and you also support parents and teachers.
Your tone must be kid-safe, friendly, clear, and encouraging. Keep explanations simple, accurate, and non-judgmental.
CRITICAL OUTPUT RULES:
- Output MUST be valid JSON only (no markdown, no code fences, no extra text).
- Use double quotes for all JSON keys/strings.
- Do not include trailing commas.
- Keep responses safe for kids (no unsafe, hateful, sexual, violent, or scary content).
SCOPE RULE:
- Only answer mathematics (Grade 5 level preferred; you may briefly define advanced terms if asked).
- If the user asks anything not related to mathematics, respond ONLY with:
{"type":"Refusal","message":"Sorry, I can't answer questions other than mathematics."}
OUTPUT TYPES:
1) Concept explanation:
{
  "type": "Concept",
  "message": "Short kid-friendly explanation..."
}
2) Practice questions (MCQ):
- Create up to 5 questions maximum. If user requests more than 5, generate only 5.
- Each question MUST have exactly 4 answer choices.
- The correct answer MUST be one of the 4 choices.
- Provide the correct answer explicitly using "CorrectOption" (A/B/C/D) and "CorrectAnswer" (exact matching text from Answers).
- Keep math appropriate and compute accurately. Avoid trick questions.
{
  "type": "Questions",
  "message": [
    {
      "Q1": "Question text",
      "Answers": { "A": "Option 1", "B": "Option 2", "C": "Option 3", "D": "Option 4" },
      "CorrectOption": "B",
      "CorrectAnswer": "Option 2"
    }
  ]
}
ACCURACY / ANTI-HALLUCINATION RULES:
- Do the math carefully. Ensure only one correct option unless the user explicitly asks for multiple correct answers.
- If you detect ambiguity (missing numbers/units), ask ONE clarifying question using:
{"type":"Concept","message":"I need one detail to answer: ..."}
(Keep it math-only and kid-safe.)
STYLE GUIDELINES:
- Use simple words and short sentences for kids.
- For parents/teachers, add a brief note on how to support learning (1–2 sentences).
- Avoid unrelated topics, brand names, or personal data requests.
""".strip()
JSON_SCHEMA: Dict[str, Any] = {
    "type": "object",
    "additionalProperties": False,
    "required": ["type", "message"],
    "properties": {
        "type": {"type": "string", "enum": ["Concept", "Questions", "Refusal"]},
        "message": {
            "anyOf": [
                {"type": "string"},
                {
                    "type": "array",
                    "maxItems": 5,
                    "items": {
                        "type": "object",
                        "additionalProperties": False,
                        "required": ["Q1", "Answers", "CorrectOption", "CorrectAnswer"],
                        "properties": {
                            "Q1": {"type": "string"},
                            "Answers": {
                                "type": "object",
                                "additionalProperties": False,
                                "required": ["A", "B", "C", "D"],
                                "properties": {
                                    "A": {"type": "string"},
                                    "B": {"type": "string"},
                                    "C": {"type": "string"},
                                    "D": {"type": "string"},
                                },
                            },
                            "CorrectOption": {"type": "string", "enum": ["A", "B", "C", "D"]},
                            "CorrectAnswer": {"type": "string"},
                        },
                    },
                },
            ]
        },
    },
}
# ----------------------------
# Vector DB + Retriever (loaded once at startup)
# ----------------------------
embedding_model = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL_NAME)
vectorstore = Chroma(
    persist_directory=CHROMA_DIR,
    embedding_function=embedding_model,
)
retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3},
)
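# NOTE: this assumes CHROMA_DIR already contains a persisted index built with
# the same embedding model; if the directory is empty, retrieval returns no
# context and answers fall back to the model's own knowledge.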
# ----------------------------
# RAG helpers
# ----------------------------
def generate_context_from_input(user_input: str) -> str:
"""Retrieve relevant chunks from the vector store and return as a single context string."""
rel_docs = retriever.invoke(user_input)
context_list = [d.page_content for d in rel_docs]
return ". ".join(context_list)
def get_llm_response(user_input: str, context: str = "") -> str:
"""Call OpenAI Chat/Responses API and return the model output text (JSON string)."""
messages = [{"role": "developer", "content": DEV_PROMPT}]
if context and context.strip():
messages.append(
{
"role": "developer",
"content": (
"Use the following CONTEXT only if it is relevant to the user's request. "
"Do not invent facts that are not in the context.\n"
"BEGIN_CONTEXT\n"
f"{context}\n"
"END_CONTEXT"
),
}
)
messages.append({"role": "user", "content": user_input})
resp = client.responses.create(
model=OPENAI_MODEL,
input=messages,
temperature=0.2,
max_output_tokens=800,
text={
"format": {
"type": "json_schema",
"name": "grade5_math_response",
"strict": True,
"schema": JSON_SCHEMA,
}
},
)
return resp.output_text
def generate_response_from_rag(user_input: str) -> Dict[str, Any]:
    context = generate_context_from_input(user_input)
    raw = get_llm_response(user_input=user_input, context=context)
    # Ensure we return valid JSON to the client even if model output is slightly off.
    try:
        return json.loads(raw)
    except Exception:
        return {"type": "Concept", "message": raw}
# ----------------------------
# Flask API
# ----------------------------
app = Flask(__name__)
@app.get("/")
def health():
    return jsonify({"status": "ok"})
@app.post("/MathQuestion")
def math_question():
    payload = request.get_json(silent=True) or {}
    query = payload.get("Query") or payload.get("query")
    if not query or not isinstance(query, str):
        return jsonify({"error": 'Missing required field "Query" (string).'}), 400
    try:
        result = generate_response_from_rag(query.strip())
        return jsonify(result)
    except Exception:
        # Log the full traceback server-side; echoing exception text to the client could leak details.
        app.logger.exception("Error while generating response")
        return jsonify({"error": "Server error while generating response."}), 500
if __name__ == "__main__":
    # Hugging Face Spaces (Docker) expects port 7860
    port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=port)