# EXAM_RAG_API / generation / ExamAnswer.py
# Uploaded by MinaNasser — commit 1bc3f18 ("1st")
import logging
from datetime import datetime
from typing import List, Dict, Any
from celery import shared_task
import json
import re
import httpx
from generation.answer_models import (ExamSubmission,ExamResult,StudentAnswer,GradedAnswer,QuestionType)
from indexing.indexingController import IndexingController
from stores.llm.LLMProviderFactory import LLMProviderFactory
from config import get_settings
def calculate_grade(percentage: float) -> str:
    """Translate a percentage score into a letter grade.

    Cut-offs (inclusive lower bounds): 90 -> "A", 80 -> "B", 70 -> "C",
    60 -> "D"; anything lower (or a value that compares False against every
    cut-off, such as NaN) yields "F".
    """
    for cutoff, letter in ((90, "A"), (80, "B"), (70, "C"), (60, "D")):
        if percentage >= cutoff:
            return letter
    return "F"
# Module-level logger for this file, named after the module.
logger = logging.getLogger(__name__)
class ExamGradingService:
    """Grade exam submissions question by question.

    * Multiple-choice and true/false questions are graded by exact,
      case-insensitive comparison with the answer key.
    * Free-text questions are graded by the configured LLM, using reference
      material retrieved from the vector store (``retrieve_context``), or by
      word-overlap (Jaccard) similarity when AI grading is disabled or no
      answer key is available.
    """

    # Spellings treated as "true" for true/false answers (after strip/lower).
    _TRUE_TOKENS = frozenset({"true", "t", "1", "yes"})

    def __init__(self, use_ai_for_essays: bool = True):
        """Create a grading service backed by the configured LLM provider.

        Args:
            use_ai_for_essays: When True, free-text answers with an answer key
                are graded by the LLM; otherwise word overlap is used.
        """
        self.use_ai_for_essays = use_ai_for_essays
        config = get_settings()
        factory = LLMProviderFactory(config)
        provider = factory.create(config.GENERATION_BACKEND)
        provider.set_generation_model(config.GENERATION_MODEL_ID)
        self.llm = provider
        # Fraction of max_score (or of similarity) at/above which a free-text
        # answer is reported as correct.
        self.semantic_threshold = 0.65
        # NOTE(review): not read anywhere in this class; kept for external users.
        self.high_confidence = 0.85

    def grade_submission(self, submission: ExamSubmission) -> ExamResult:
        """Grade every answer in *submission* and aggregate the totals.

        The answer key for each question is read from
        ``answer.metadata["correct_answer"]`` (None when metadata is absent).

        Returns:
            ExamResult with the per-question results, summed scores, the
            percentage (0 when no points are attainable) and the letter grade.
        """
        graded_answers: List[GradedAnswer] = []
        total_score = 0
        max_total_score = 0
        for ans in submission.answers:
            correct_answer = ans.metadata.get("correct_answer") if ans.metadata else None
            graded = self.grade_answer(ans, correct_answer, submission.course_id)
            graded_answers.append(graded)
            total_score += graded.score
            max_total_score += graded.max_score
        percentage = (total_score / max_total_score) * 100 if max_total_score else 0
        return ExamResult(
            exam_id=submission.exam_id,
            student_id=submission.student_id,
            student_name=submission.student_name,
            graded_answers=graded_answers,
            total_score=total_score,
            max_total_score=max_total_score,
            percentage=percentage,
            grade=calculate_grade(percentage),
            feedback_summary="RAG based grading using LLM evaluation",
            submission_time=submission.submission_time,
            # NOTE(review): naive UTC timestamp (no offset); utcnow() is
            # deprecated since Python 3.12 but the string format is kept
            # because webhook consumers receive it.
            graded_time=datetime.utcnow().isoformat(),
        )

    def grade_answer(self, answer: StudentAnswer, correct_answer: Any, course) -> GradedAnswer:
        """Grade a single answer.

        Args:
            answer: The student's answer together with the question details.
            correct_answer: Answer key for the question, or None if missing.
            course: Course id used to filter retrieved reference material.

        Returns:
            GradedAnswer carrying the score, feedback and correctness flag.
        """
        if answer.question_type in (QuestionType.MULTIPLE_CHOICE, QuestionType.TRUE_FALSE):
            score, is_correct, feedback = self._grade_exact(answer, correct_answer)
        else:
            score, is_correct, feedback = self._grade_free_text(answer, correct_answer, course)
        return GradedAnswer(
            question_no=answer.question_no,
            question_type=answer.question_type,
            question_text=answer.question_text,
            student_response=answer.student_response,
            correct_answer=correct_answer,
            score=score,
            max_score=answer.max_score,
            feedback=feedback,
            is_correct=is_correct,
        )

    @staticmethod
    def _has_answer_key(correct_answer: Any) -> bool:
        """Return True when *correct_answer* is a usable answer key.

        None and blank strings mean "no key"; falsy but meaningful keys such
        as ``0`` or ``False`` are accepted.
        """
        if correct_answer is None:
            return False
        if isinstance(correct_answer, str):
            return bool(correct_answer.strip())
        return True

    def _grade_exact(self, answer: StudentAnswer, correct_answer: Any):
        """Exact-match grading for multiple-choice and true/false questions.

        Returns:
            ``(score, is_correct, feedback)``.
        """
        if not self._has_answer_key(correct_answer):
            # Without a key an empty MCQ response would "match" "" and a
            # "false" T/F response would match bool(None); never award points.
            return 0, False, "No correct answer provided"
        student_str = str(answer.student_response).strip().lower()
        if answer.question_type == QuestionType.TRUE_FALSE:
            if isinstance(correct_answer, bool):
                correct_bool = correct_answer
            elif isinstance(correct_answer, str):
                correct_bool = correct_answer.strip().lower() in self._TRUE_TOKENS
            else:
                correct_bool = bool(correct_answer)
            is_correct = (student_str in self._TRUE_TOKENS) == correct_bool
        else:  # multiple choice
            is_correct = student_str == str(correct_answer).strip().lower()
        score = answer.max_score if is_correct else 0
        return score, is_correct, "Exact match grading"

    def _grade_free_text(self, answer: StudentAnswer, correct_answer: Any, course):
        """Grade a free-text answer with the LLM, or by word overlap.

        Returns:
            ``(score, is_correct, feedback)``.
        """
        if self.use_ai_for_essays and self._has_answer_key(correct_answer):
            score, feedback = self.ai_semantic_grade(
                answer.question_text,
                answer.student_response,
                correct_answer,
                answer.max_score,
                course=course,
            )
            is_correct = score >= answer.max_score * self.semantic_threshold
            return score, is_correct, feedback
        similarity = self.simple_similarity(answer.student_response, correct_answer)
        is_correct = similarity >= self.semantic_threshold
        return similarity * answer.max_score, is_correct, f"Similarity score {similarity:.2f}"

    def simple_similarity(self, student: str, correct: str) -> float:
        """Return the Jaccard similarity of the two answers' word sets.

        Words are whitespace-separated and compared case-insensitively.
        Non-string inputs are compared via ``str()``. Returns 0.0 when either
        side is empty/None or neither side contains any word.
        """
        if not student or not correct:
            return 0.0
        student_words = set(str(student).lower().split())
        correct_words = set(str(correct).lower().split())
        union = student_words | correct_words
        if not union:
            # Both sides were whitespace only; avoid ZeroDivisionError.
            return 0.0
        return len(student_words & correct_words) / len(union)

    def retrieve_context(self, question: str, course: str, top_k: int = 5) -> str:
        """Retrieve reference material for *question* from the vector store.

        Args:
            question: Question text to embed and search with.
            course: Course id; when truthy, only chunks whose ``course`` field
                equals it are returned.
            top_k: Maximum number of chunks to retrieve.

        Returns:
            The non-empty ``content`` values of the retrieved chunks joined by
            newlines, or "" if retrieval fails (the error is logged).
        """
        try:
            controller = IndexingController()
            embedding = controller.embedder.embed_text(question)
            filters = []
            if course:
                filters.append({
                    "field": "course",
                    "op": "eq",
                    "value": course,
                    "clause": "must",
                })
            results = controller.vector_store.query_qdrant(
                embedding=embedding, filters=filters, top_k=top_k
            )
            context = "\n".join(r["content"] for r in results if r.get("content"))
            logger.info(f"Retrieved {len(results)} chunks for question (filtered by course={course})")
            return context
        except Exception as e:
            logger.error(f"Context retrieval failed: {e}")
            return ""

    def build_prompt(self, question, student_answer, correct_answer, context):
        """Build the grading prompt sent to the LLM.

        The model is asked to reply with JSON holding a ``score`` in [0, 1]
        and a short ``feedback`` string, which ``parse_llm_output`` reads.

        NOTE(review): the student's answer is inserted verbatim, so it can
        contain instructions aimed at the grader (prompt injection).
        """
        return f"""
You are an academic exam grader.
Question:
{question}
Correct Answer:
{correct_answer}
Reference Material:
{context}
Student Answer:
{student_answer}
Evaluate the student answer using semantic similarity.
You may slightly use your knowledge if correct answer not in Reference Material.
Return JSON only:
{{
"score": number between 0 and 1,
"feedback": short explanation
}}
"""

    def parse_llm_output(self, text: Any):
        """Extract ``(score_ratio, feedback)`` from an LLM response.

        Accepts a string, a dict (its ``"response"`` value is parsed, or the
        dict itself when it already holds a ``"score"``), or an object that
        exposes ``.content`` / ``.text``. Markdown code fences are stripped;
        if the text is not pure JSON, the outermost ``{...}`` span is parsed.

        Returns:
            The score clamped to [0, 1] and the feedback. ``(0, message)`` is
            returned for an empty or unparseable response.
        """
        try:
            if isinstance(text, dict):
                if "response" in text:
                    text = text["response"]
                elif "score" in text:
                    # Provider already returned structured output.
                    return self._score_and_feedback(text)
                else:
                    text = str(text)
            elif hasattr(text, "content"):
                text = text.content
            elif hasattr(text, "text"):
                text = text.text
            text = str(text).strip()
            if not text:
                return 0, "Empty response from LLM"
            text = re.sub(r'```json\s*|\s*```', '', text)
            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                json_match = re.search(r'\{.*\}', text, re.DOTALL)
                if not json_match:
                    raise
                data = json.loads(json_match.group())
            return self._score_and_feedback(data)
        except Exception as e:
            logger.error(f"Failed to parse LLM output: {e}, text type: {type(text)}")
            return 0, "Failed to parse AI grading"

    @staticmethod
    def _score_and_feedback(data: Dict[str, Any]):
        """Return ``(score, feedback)`` from parsed JSON, score clamped to [0, 1]."""
        score = float(data.get("score", 0))
        return max(0, min(score, 1)), data.get("feedback", "")

    def ai_semantic_grade(self, question, student, correct, max_score, course):
        """Grade a free-text answer with the LLM using retrieved course context.

        Args:
            question: Question text (also used as the retrieval query).
            student: The student's answer.
            correct: The answer key.
            max_score: Points available for this question.
            course: Course id used to filter the retrieved context.

        Returns:
            ``(score, feedback)`` where score is a [0, 1] ratio times
            *max_score*. If any step raises, falls back to word-overlap
            similarity.
        """
        try:
            # Retrieve context filtered by course.
            context = self.retrieve_context(question, course)
            prompt = self.build_prompt(question, student, correct, context)
            response = self.llm.generate_text(prompt)
            logger.debug(f"Response type: {type(response)}")
            score_ratio, feedback = self.parse_llm_output(response)
            return score_ratio * max_score, feedback
        except Exception as e:
            logger.error(f"AI grading failed: {e}")
            similarity = self.simple_similarity(student, correct)
            return similarity * max_score, f"Fallback similarity grading: {similarity:.2f}"
def _send_grade_webhook(submission, result_dict: Dict[str, Any]) -> None:
    """POST the grading result to ``GRADE_WEBHOOK_URL`` (best effort).

    Any failure is logged and swallowed so that a webhook outage does not
    fail an otherwise successful grading task.

    Args:
        submission: The graded ``ExamSubmission`` (supplies the ids).
        result_dict: ``ExamResult.model_dump()`` output.
    """
    try:
        webhook_url = get_settings().GRADE_WEBHOOK_URL
        if not webhook_url:
            logger.warning("GRADE_WEBHOOK_URL is empty or not set!")
            return
        logger.info("Webhook URL: %s", webhook_url)
        # "grade" is a compact summary; "result" additionally carries the full
        # graded result (per-question answers and feedback).
        payload = {
            "status": "completed",
            "exam_id": submission.exam_id,
            "student_id": submission.student_id,
            "course_id": submission.course_id,
            "grade": {
                key: result_dict[key]
                for key in ("total_score", "max_total_score", "percentage", "grade", "graded_time")
            },
            "result": result_dict,
        }
        response = httpx.post(webhook_url, json=payload, timeout=30.0)
        logger.info("Response status: %s", response.status_code)
        if 200 <= response.status_code < 300:
            logger.info("Grade webhook sent successfully!")
        else:
            logger.warning(
                "Webhook returned status: %s; response: %s",
                response.status_code,
                response.text[:200],
            )
    except Exception as e:
        logger.exception("Webhook error: %s: %s", type(e).__name__, e)


@shared_task
def grade_exam_task(submission_dict: Dict[str, Any]):
    """Celery task: grade a submission and notify the grade webhook.

    Args:
        submission_dict: Keyword arguments for ``ExamSubmission``.

    Returns:
        The graded ``ExamResult`` as a dict (``model_dump()``).

    Raises:
        Any exception from validation or grading is logged with its traceback
        and re-raised so the task is marked as failed. Webhook delivery
        errors are logged but do not fail the task.
    """
    try:
        submission = ExamSubmission(**submission_dict)
        service = ExamGradingService()
        result_dict = service.grade_submission(submission).model_dump()
        _send_grade_webhook(submission, result_dict)
        logger.info("Task completed successfully")
        return result_dict
    except Exception as e:
        logger.exception("ERROR in task: %s: %s", type(e).__name__, e)
        raise