Got feedback endpoint working
Browse files- backend/app/main.py +28 -11
- backend/app/problem_generator.py +0 -1
- backend/app/problem_grader.py +3 -4
- backend/tests/test_api.py +45 -0
backend/app/main.py
CHANGED
|
@@ -4,7 +4,9 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 4 |
from fastapi.responses import FileResponse
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from backend.app.problem_generator import ProblemGenerationPipeline
|
|
|
|
| 7 |
from typing import Dict, List
|
|
|
|
| 8 |
|
| 9 |
app = FastAPI()
|
| 10 |
|
|
@@ -22,11 +24,15 @@ class UrlInput(BaseModel):
|
|
| 22 |
class UserQuery(BaseModel):
|
| 23 |
user_query: str
|
| 24 |
|
| 25 |
-
|
|
|
|
| 26 |
user_query: str
|
| 27 |
problems: list[str]
|
| 28 |
user_answers: list[str]
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
@app.post("/api/crawl/")
|
| 31 |
async def crawl_documentation(input_data: UrlInput):
|
| 32 |
print(f"Received url {input_data.url}")
|
|
@@ -37,17 +43,28 @@ async def generate_problems(query: UserQuery):
|
|
| 37 |
problems = ProblemGenerationPipeline().generate_problems(query.user_query)
|
| 38 |
return {"Problems": problems}
|
| 39 |
|
| 40 |
-
@app.post("/api/feedback
|
| 41 |
-
async def
|
| 42 |
-
|
| 43 |
-
if len(feedback.problems) != len(feedback.user_answers):
|
| 44 |
raise HTTPException(status_code=400, detail="Problems and user answers must have the same length")
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Serve static files
|
| 53 |
app.mount("/static", StaticFiles(directory="/app/static/static"), name="static")
|
|
|
|
| 4 |
from fastapi.responses import FileResponse
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from backend.app.problem_generator import ProblemGenerationPipeline
|
| 7 |
+
from backend.app.problem_grader import ProblemGradingPipeline
|
| 8 |
from typing import Dict, List
|
| 9 |
+
import asyncio
|
| 10 |
|
| 11 |
app = FastAPI()
|
| 12 |
|
|
|
|
| 24 |
class UserQuery(BaseModel):
|
| 25 |
user_query: str
|
| 26 |
|
| 27 |
+
# TODO: Make this a list of {problem: str, answer: str}. Would be cleaner for data validation
|
| 28 |
+
class FeedbackRequest(BaseModel):
|
| 29 |
user_query: str
|
| 30 |
problems: list[str]
|
| 31 |
user_answers: list[str]
|
| 32 |
|
| 33 |
+
class FeedbackResponse(BaseModel):
|
| 34 |
+
feedback: List[str]
|
| 35 |
+
|
| 36 |
@app.post("/api/crawl/")
|
| 37 |
async def crawl_documentation(input_data: UrlInput):
|
| 38 |
print(f"Received url {input_data.url}")
|
|
|
|
| 43 |
problems = ProblemGenerationPipeline().generate_problems(query.user_query)
|
| 44 |
return {"Problems": problems}
|
| 45 |
|
| 46 |
+
@app.post("/api/feedback", response_model=FeedbackResponse)
|
| 47 |
+
async def get_feedback(request: FeedbackRequest):
|
| 48 |
+
if len(request.problems) != len(request.user_answers):
|
|
|
|
| 49 |
raise HTTPException(status_code=400, detail="Problems and user answers must have the same length")
|
| 50 |
+
try:
|
| 51 |
+
grader = ProblemGradingPipeline()
|
| 52 |
+
|
| 53 |
+
grading_tasks = [
|
| 54 |
+
grader.grade(
|
| 55 |
+
query=request.user_query,
|
| 56 |
+
problem=problem,
|
| 57 |
+
answer=user_answer,
|
| 58 |
+
)
|
| 59 |
+
for problem, user_answer in zip(request.problems, request.user_answers)
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
feedback_list = await asyncio.gather(*grading_tasks)
|
| 63 |
+
|
| 64 |
+
return FeedbackResponse(feedback=feedback_list)
|
| 65 |
+
|
| 66 |
+
except Exception as e:
|
| 67 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 68 |
|
| 69 |
# Serve static files
|
| 70 |
app.mount("/static", StaticFiles(directory="/app/static/static"), name="static")
|
backend/app/problem_generator.py
CHANGED
|
@@ -36,7 +36,6 @@ class ProblemGenerationPipeline:
|
|
| 36 |
self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
|
| 37 |
self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})
|
| 38 |
|
| 39 |
-
# Build the RAG chain
|
| 40 |
self.rag_chain = (
|
| 41 |
{"context": self.retriever, "query": RunnablePassthrough()}
|
| 42 |
| self.chat_prompt
|
|
|
|
| 36 |
self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
|
| 37 |
self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})
|
| 38 |
|
|
|
|
| 39 |
self.rag_chain = (
|
| 40 |
{"context": self.retriever, "query": RunnablePassthrough()}
|
| 41 |
| self.chat_prompt
|
backend/app/problem_grader.py
CHANGED
|
@@ -40,7 +40,6 @@ class ProblemGradingPipeline:
|
|
| 40 |
self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)
|
| 41 |
self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})
|
| 42 |
|
| 43 |
-
# Build the RAG chain
|
| 44 |
self.rag_chain = (
|
| 45 |
{
|
| 46 |
"context": self.retriever,
|
|
@@ -53,9 +52,9 @@ class ProblemGradingPipeline:
|
|
| 53 |
| StrOutputParser()
|
| 54 |
)
|
| 55 |
|
| 56 |
-
def grade(self, query: str, problem: str, answer: str) -> str:
|
| 57 |
"""
|
| 58 |
-
|
| 59 |
|
| 60 |
Args:
|
| 61 |
query (str): The topic/context to use for grading
|
|
@@ -65,7 +64,7 @@ class ProblemGradingPipeline:
|
|
| 65 |
Returns:
|
| 66 |
str: Grading response indicating if the answer is correct and providing feedback
|
| 67 |
"""
|
| 68 |
-
return self.rag_chain.
|
| 69 |
"query": query,
|
| 70 |
"problem": problem,
|
| 71 |
"answer": answer
|
|
|
|
| 40 |
self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)
|
| 41 |
self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})
|
| 42 |
|
|
|
|
| 43 |
self.rag_chain = (
|
| 44 |
{
|
| 45 |
"context": self.retriever,
|
|
|
|
| 52 |
| StrOutputParser()
|
| 53 |
)
|
| 54 |
|
| 55 |
+
async def grade(self, query: str, problem: str, answer: str) -> str:
|
| 56 |
"""
|
| 57 |
+
Asynchronously grade a student's answer to a problem using RAG for context-aware evaluation.
|
| 58 |
|
| 59 |
Args:
|
| 60 |
query (str): The topic/context to use for grading
|
|
|
|
| 64 |
Returns:
|
| 65 |
str: Grading response indicating if the answer is correct and providing feedback
|
| 66 |
"""
|
| 67 |
+
return await self.rag_chain.ainvoke({
|
| 68 |
"query": query,
|
| 69 |
"problem": problem,
|
| 70 |
"answer": answer
|
backend/tests/test_api.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from fastapi.testclient import TestClient
|
| 2 |
from backend.app.main import app
|
|
|
|
| 3 |
|
| 4 |
client = TestClient(app)
|
| 5 |
|
|
@@ -19,4 +20,48 @@ def test_problems_endpoint():
|
|
| 19 |
assert response.status_code == 200
|
| 20 |
assert "Problems" in response.json()
|
| 21 |
assert len(response.json()["Problems"]) == 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
|
|
|
| 1 |
from fastapi.testclient import TestClient
|
| 2 |
from backend.app.main import app
|
| 3 |
+
import pytest
|
| 4 |
|
| 5 |
client = TestClient(app)
|
| 6 |
|
|
|
|
| 20 |
assert response.status_code == 200
|
| 21 |
assert "Problems" in response.json()
|
| 22 |
assert len(response.json()["Problems"]) == 5
|
| 23 |
+
|
| 24 |
+
def test_feedback_validation_error():
|
| 25 |
+
"""Test that mismatched problems and answers lengths return 400"""
|
| 26 |
+
response = client.post(
|
| 27 |
+
"/api/feedback",
|
| 28 |
+
json={
|
| 29 |
+
"user_query": "Python lists",
|
| 30 |
+
"problems": ["What is a list?", "How do you append?"],
|
| 31 |
+
"user_answers": ["A sequence",] # Only one answer
|
| 32 |
+
}
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
assert response.status_code == 400
|
| 36 |
+
assert "same length" in response.json()["detail"]
|
| 37 |
+
|
| 38 |
+
@pytest.mark.asyncio
|
| 39 |
+
async def test_successful_feedback():
|
| 40 |
+
"""Test successful grading of multiple problems"""
|
| 41 |
+
response = client.post(
|
| 42 |
+
"/api/feedback",
|
| 43 |
+
json={
|
| 44 |
+
"user_query": "RAG",
|
| 45 |
+
"problems": [
|
| 46 |
+
"What are the two main components of a typical RAG application?",
|
| 47 |
+
"What is the purpose of the indexing component in a RAG application?"
|
| 48 |
+
],
|
| 49 |
+
"user_answers": [
|
| 50 |
+
"A list is a mutable sequence type that can store multiple items in Python",
|
| 51 |
+
"You use the append() method to add an element to the end of a list"
|
| 52 |
+
]
|
| 53 |
+
}
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
assert response.status_code == 200
|
| 57 |
+
result = response.json()
|
| 58 |
+
assert "feedback" in result
|
| 59 |
+
assert len(result["feedback"]) == 2
|
| 60 |
+
|
| 61 |
+
# Check that responses start with either "Correct" or "Incorrect"
|
| 62 |
+
for feedback in result["feedback"]:
|
| 63 |
+
assert feedback.startswith(("Correct", "Incorrect"))
|
| 64 |
+
# Check that there's an explanation after the classification
|
| 65 |
+
assert len(feedback.split(". ")) >= 2
|
| 66 |
+
|
| 67 |
|