abhikamuni's picture
l
dcafa2f verified
import os
# --- KEEP THE SENTENCE-TRANSFORMERS FIX ---
# Point the Hugging Face cache (HF_HOME) at the folder created in the Dockerfile.
os.environ["HF_HOME"] = "/code/.cache/huggingface"
# The DSPy-specific environment variable that used to be set here was removed.
# Every other import must stay BELOW this point so that HF_HOME is already set
# by the time those modules are imported.
import json
import uuid
from datetime import datetime
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, Any
# Import our new modular services
# We need to make sure the path is correct
from app.services.guardrails import check_input_guardrail, check_output_guardrail
from app.services.rag_pipeline import generate_solution
# --- REMOVED THE DSPY IMPORT ---
# from app.services.dspy_feedback import refine_solution_with_dspy
from app.schemas import (
AskRequest, AskResponse, FeedbackRequest, FeedbackResponse
)
# Initialize FastAPI
app = FastAPI(title="Math Routing Agent (Stateless HITL Version)")

# Origin of the React frontend; override with the FRONTEND_URL env var.
CLIENT_URL = os.getenv("FRONTEND_URL", "http://localhost:3000")

# --- CORS Middleware ---
# `allow_origins` entries are compared to the request's Origin header as exact
# strings, so a glob such as "https://*.hf.space" never matches any origin.
# Wildcard subdomains have to be expressed through `allow_origin_regex`, which
# is matched against the whole Origin value.
# NOTE(review): with allow_credentials=True this lets any Space hosted under
# hf.space make credentialed requests -- narrow the pattern to the specific
# Space hostname if that is broader than intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[CLIENT_URL],  # Allows React app
    allow_origin_regex=r"https://[A-Za-z0-9.-]+\.hf\.space",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- API Endpoints ---
@app.post("/ask/", response_model=AskResponse)
async def ask_math_question(request: AskRequest):
    """
    Answer a math question, guarding both the input and the generated output.

    Flow:
      1. Run the input guardrail on the question; reject with HTTP 400 if it
         reports the input as unsafe.
      2. Await the RAG pipeline to obtain a (solution, source) pair; any
         exception is printed and converted to HTTP 500.
      3. Run the output guardrail on the solution; reject with HTTP 500 if it
         reports the output as unsafe.
      4. Return an AskResponse carrying a freshly generated thread id.
    """
    question = request.question

    input_ok, block_reason = check_input_guardrail(question)
    if not input_ok:
        raise HTTPException(status_code=400, detail=f"Input blocked: {block_reason}")

    try:
        solution, source = await generate_solution(question)
    except Exception as exc:
        print(f"--- Main Error (generate_solution): {exc} ---")
        raise HTTPException(status_code=500, detail="Agent failed to process.")

    output_ok, guardrail_text = check_output_guardrail(solution)
    if not output_ok:
        raise HTTPException(status_code=500, detail=f"Output blocked: {guardrail_text}")

    # NOTE(review): the response carries the second value returned by the
    # output guardrail, not `solution` itself -- presumably the guardrail hands
    # back the approved solution text on success; confirm in
    # app.services.guardrails.
    turn_id = str(uuid.uuid4())  # New ID for this "turn"
    return AskResponse(
        solution=guardrail_text,
        source=source,
        thread_id=turn_id,
        question=question,
    )
@app.post("/feedback/", response_model=FeedbackResponse, status_code=200)
async def give_feedback(request: FeedbackRequest):
    """
    Endpoint to receive feedback.

    The DSPy refinement has been removed, so the feedback is only printed to
    the console and the original solution is echoed back unchanged, whatever
    the rating is.

    Returns:
        FeedbackResponse with the caller's original solution, thread id and
        question, and ``source`` set to ``"feedback_logged"``.
    """
    # Function-scope import keeps this block self-contained; the module only
    # imports `datetime` from the datetime package.
    from datetime import timezone

    print(f"--- HITL: Received Feedback for {request.thread_id} ---")

    # 1. Log the feedback. This is deliberately best-effort: a failure while
    #    building or serializing the entry must not fail the request.
    try:
        feedback_entry = request.model_dump()
        # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated and
        # returns a naive value); the ISO string now carries a "+00:00" offset.
        feedback_entry["timestamp"] = datetime.now(timezone.utc).isoformat()
        print(f"--- HITL_FEEDBACK_ENTRY: {json.dumps(feedback_entry)}")
        print("--- HITL: Feedback logged to console. ---")
    except Exception as e:
        print(f"--- HITL: Error logging feedback: {e} ---")

    # 2. Report the rating. Refinement is disabled, so nothing depends on
    #    whether feedback text was supplied; branch on the rating alone and
    #    print the actual value instead of assuming "good" for everything else.
    if request.rating == "bad":
        print("--- HITL: Rating is 'bad'. Logging only (refinement disabled). ---")
    else:
        print(f"--- HITL: Rating is '{request.rating}'. Logging only. ---")

    # 3. Echo the original solution back to the caller.
    return FeedbackResponse(
        solution=request.original_solution,
        source="feedback_logged",
        thread_id=request.thread_id,
        question=request.question,
    )
@app.get("/")
def read_root():
    """Root endpoint: return a fixed greeting showing that the API is running."""
    greeting = {"Hello": "Math Agent API is running (Stateless HITL Version)."}
    return greeting