|
|
""" |
|
|
Flask-SocketIO server for the PIPS front-end. |
|
|
|
|
|
Matches the Socket.IO events used by the front-end template
(templates/index_modular.html):

  • session_connected
  • settings_updated
  • solving_started / step_update / llm_streaming_* / code_execution_* / code_check
  • solving_complete / solving_error / solving_interrupted
  • heartbeat_response
  • download_chat_log
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
import os |
|
|
import threading |
|
|
import time |
|
|
from datetime import datetime |
|
|
from typing import Any, Dict |
|
|
|
|
|
from flask import Flask, render_template, request, jsonify |
|
|
from flask_socketio import SocketIO, emit |
|
|
|
|
|
|
|
|
from .models import AVAILABLE_MODELS, get_model |
|
|
from .core import PIPSSolver, PIPSMode |
|
|
from .utils import RawInput, base642img |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = Flask(__name__, template_folder="templates")

# Flask signs its session cookies with SECRET_KEY, so a deployment should set
# PIPS_SECRET_KEY in the environment. The historical placeholder is kept as
# the fallback so local development works unchanged.
app.config["SECRET_KEY"] = os.environ.get("PIPS_SECRET_KEY", "change-me")

# NOTE(review): any origin may open a Socket.IO connection; restrict this
# when the server is exposed beyond localhost.
socketio = SocketIO(app, cors_allowed_origins="*")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The first entry of AVAILABLE_MODELS is the default for every model slot.
_DEFAULT_MODEL_ID = next(iter(AVAILABLE_MODELS))

# Settings every new session starts from; clients override them through the
# "update_settings" event.
DEFAULT_SETTINGS = {
    # Model selection (legacy single-model key plus generator/critic split).
    "model": _DEFAULT_MODEL_ID,
    # Provider API keys; empty until the user supplies them.
    "openai_api_key": "",
    "google_api_key": "",
    "anthropic_api_key": "",
    # Solver limits and sampling parameters.
    "max_iterations": 8,
    "temperature": 0.0,
    "max_tokens": 4096,
    "max_execution_time": 10,
    # Solving mode: "AGENT" or "INTERACTIVE".
    "pips_mode": "AGENT",
    "generator_model": _DEFAULT_MODEL_ID,
    "critic_model": _DEFAULT_MODEL_ID,
    # Free-form rules appended to the prompt, and per-prompt overrides.
    "custom_rules": "",
    "prompt_overrides": {},
}
|
|
|
|
|
# Per-client state keyed by Socket.IO session id; on_connect seeds it with
# "settings" and "chat", and the solve handlers add "solver"/"mode_decision".
sessions: Dict[str, Dict[str, Any]] = dict()

# Running background tasks keyed by session id; each entry holds the stop
# "event" (a threading.Event) and the "task" handle.
active_tasks: Dict[str, Dict[str, Any]] = dict()
|
|
|
|
|
def _safe(obj): |
|
|
"""JSON-serialise anything (fractions etc. become strings).""" |
|
|
if obj is None or isinstance(obj, (str, int, float, bool)): |
|
|
return obj |
|
|
if isinstance(obj, list): |
|
|
return [_safe(x) for x in obj] |
|
|
if isinstance(obj, dict): |
|
|
return {k: _safe(v) for k, v in obj.items()} |
|
|
return str(obj) |
|
|
|
|
|
|
|
|
def make_callbacks(sid: str, generator_model_name: str, critic_model_name: str, stop_evt: threading.Event, max_exec: int):
    """Return the callback mapping PIPSSolver uses when streaming (stream=True).

    Each ``on_*`` hook forwards its arguments to the client's Socket.IO room
    (``room=sid``) as an event. Generator streaming events are tagged with
    *generator_model_name*, and code-review streaming events with
    *critic_model_name*; the model argument the solver passes in (``m``) is
    ignored. ``check_interrupted`` reports whether *stop_evt* is set, and
    ``get_max_execution_time`` returns *max_exec*.
    """

    def _emit(event: str, payload: dict):
        # Log every emission; streaming tokens get a 20-character preview.
        if event == "llm_streaming_token":
            log_line = f"[DEBUG] Emitting token for session {sid}: '{payload.get('token', '')[:20]}...'"
        elif event == "code_check_streaming_token":
            log_line = f"[DEBUG] Emitting code reviewer token for session {sid}: '{payload.get('token', '')[:20]}...'"
        else:
            log_line = f"[DEBUG] Emitting {event} for session {sid}"
        print(log_line)
        socketio.emit(event, payload, room=sid)
        # Zero-length cooperative sleep; presumably lets the async worker
        # flush each event as it happens rather than batching them.
        socketio.sleep(0)

    def on_step_update(step, msg, iteration=None, prompt_details=None, **_):
        _emit("step_update", {"step": step, "message": msg, "iteration": iteration, "prompt_details": prompt_details})

    # Generator (solution-writing) model streaming.
    def on_llm_streaming_start(it, m):
        _emit("llm_streaming_start", {"iteration": it, "model_name": generator_model_name})

    def on_llm_streaming_token(tok, it, m):
        _emit("llm_streaming_token", {"token": tok, "iteration": it, "model_name": generator_model_name})

    def on_llm_streaming_end(it, m):
        _emit("llm_streaming_end", {"iteration": it, "model_name": generator_model_name})

    # Critic (code-review) model streaming.
    def on_code_check_streaming_start(it, m):
        _emit("code_check_streaming_start", {"iteration": it, "model_name": critic_model_name})

    def on_code_check_streaming_token(tok, it, m):
        _emit("code_check_streaming_token", {"token": tok, "iteration": it, "model_name": critic_model_name})

    def on_code_check_streaming_end(it, m):
        _emit("code_check_streaming_end", {"iteration": it, "model_name": critic_model_name})

    # Code execution lifecycle.
    def on_code_execution_start(it):
        _emit("code_execution_start", {"iteration": it})

    def on_code_execution_end(it):
        _emit("code_execution_end", {"iteration": it})

    def on_code_execution(it, out, stdout, err):
        _emit("code_execution", {"iteration": it, "output": str(out), "stdout": stdout, "error": err})

    def on_error(msg):
        _emit("solving_error", {"error": msg})

    def get_max_execution_time():
        return max_exec

    # Interactive mode: hand the critic's review to the user.
    def on_waiting_for_user(iteration, critic_text, code, symbols):
        _emit(
            "awaiting_user_feedback",
            {"iteration": iteration, "critic_text": critic_text, "code": code, "symbols": _safe(symbols)},
        )

    return {
        "on_step_update": on_step_update,
        "on_llm_streaming_start": on_llm_streaming_start,
        "on_llm_streaming_token": on_llm_streaming_token,
        "on_llm_streaming_end": on_llm_streaming_end,
        "on_code_check_streaming_start": on_code_check_streaming_start,
        "on_code_check_streaming_token": on_code_check_streaming_token,
        "on_code_check_streaming_end": on_code_check_streaming_end,
        "on_code_execution_start": on_code_execution_start,
        "on_code_execution_end": on_code_execution_end,
        "on_code_execution": on_code_execution,
        "on_error": on_error,
        "check_interrupted": stop_evt.is_set,
        "get_max_execution_time": get_max_execution_time,
        "on_waiting_for_user": on_waiting_for_user,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/")
def index():
    """Serve the front-end page, seeded with the model list and default settings."""
    template_context = {
        "available_models": AVAILABLE_MODELS,
        "default_settings": DEFAULT_SETTINGS,
    }
    return render_template("index_modular.html", **template_context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@socketio.on("connect")
def on_connect():
    """Create fresh per-client state and tell the client its session id."""
    import copy

    sid = request.sid
    # Deep-copy the defaults. A shallow dict.copy() would make every session
    # share the same nested "prompt_overrides" dict, so an in-place edit for
    # one client would leak into every other client and into the defaults.
    sessions[sid] = dict(settings=copy.deepcopy(DEFAULT_SETTINGS), chat=[])
    emit("session_connected", {"session_id": sid})
    print(f"[CONNECT] {sid}")
|
|
|
|
|
|
|
|
@socketio.on("disconnect")
def on_disconnect():
    """Signal the departing client's running task to stop and drop its state."""
    sid = request.sid
    running = active_tasks.pop(sid, None)
    if running is not None:
        running["event"].set()
    sessions.pop(sid, None)
    print(f"[DISCONNECT] {sid}")
|
|
|
|
|
|
|
|
@socketio.on("update_settings")
def on_update_settings(data=None):
    """Merge client-supplied settings into this session's settings.

    Replies with ``settings_updated``, either ``{"status": "success",
    "settings": ...}`` or ``{"status": "error", "message": ...}``.

    The numeric fields used by the solver are coerced here (``int`` for the
    limits, ``float`` for temperature). A value sent as a string, e.g. from a
    form input, would otherwise crash the solve later, for example at
    ``min(1024, settings["max_tokens"])``. The update is all-or-nothing:
    nothing is stored if any field fails to convert.
    """
    sid = request.sid
    if sid not in sessions:
        emit("settings_updated", {"status": "error", "message": "No session"})
        return
    # A client may emit the event without a payload, or with a non-object one.
    if not isinstance(data, dict):
        emit("settings_updated", {"status": "error", "message": "Settings must be an object"})
        return

    converters = {
        "max_iterations": int,
        "max_tokens": int,
        "max_execution_time": int,
        "temperature": float,
    }
    updates = dict(data)
    for key, convert in converters.items():
        if key not in updates:
            continue
        try:
            updates[key] = convert(updates[key])
        except (TypeError, ValueError):
            emit("settings_updated", {"status": "error", "message": f"Invalid value for {key}: {updates[key]!r}"})
            return

    sessions[sid]["settings"].update(updates)
    emit("settings_updated", {"status": "success", "settings": sessions[sid]["settings"]})
|
|
|
|
|
|
|
|
def _combine_custom_rules(sid: str, settings: Dict[str, Any]) -> str:
    """Build the ``additional_rules`` prompt text from a session's settings.

    Global and session rules are each labelled and joined by a blank line.
    The legacy ``custom_rules`` value is used only when neither newer field
    is set. The inputs and the result are logged for debugging.
    """
    global_rules = settings.get("global_rules", "")
    session_rules = settings.get("session_rules", "")
    legacy_custom_rules = settings.get("custom_rules", "")

    combined_rules = []
    if global_rules:
        combined_rules.append(f"Global Rules:\n{global_rules}")
    if session_rules:
        combined_rules.append(f"Session Rules:\n{session_rules}")
    if legacy_custom_rules and not global_rules and not session_rules:
        combined_rules.append(legacy_custom_rules)
    custom_rules = "\n\n".join(combined_rules)

    print(f"[DEBUG] Custom rules processing for session {sid}:")
    print(f" Global rules: {repr(global_rules)}")
    print(f" Session rules: {repr(session_rules)}")
    print(f" Legacy rules: {repr(legacy_custom_rules)}")
    print(f" Combined rules: {repr(custom_rules)}")
    return custom_rules


def _api_key_for_model(settings: Dict[str, Any], model_id: str):
    """Return the provider API key in *settings* for *model_id*, or None.

    Routing is by model id: a ``gpt``/``o3``/``o4`` prefix means OpenAI,
    ``gemini`` anywhere means Google, and ``claude`` anywhere means Anthropic.
    """
    if model_id.startswith(("gpt", "o3", "o4")):
        return settings.get("openai_api_key")
    if "gemini" in model_id:
        return settings.get("google_api_key")
    if "claude" in model_id:
        return settings.get("anthropic_api_key")
    return None


@socketio.on("solve_problem")
def on_solve_problem(data=None):
    """Validate a solve request and run PIPSSolver in a background task.

    ``data`` may contain ``text`` (required problem statement) and ``image``
    (optional ``data:image...`` base64 URL). Validation failures are reported
    with ``solving_error``. The background task streams progress through
    make_callbacks and ends with one of the following:

    - ``final_artifacts`` + ``solving_complete`` on success;
    - ``solving_interrupted`` when the stop event was set;
    - ``solving_error`` on an exception;
    - nothing further, when interactive mode pauses for user feedback (the
      solver is then parked in ``sessions[sid]["solver"]``).
    """
    sid = request.sid
    if sid not in sessions:
        emit("solving_error", {"error": "Session vanished"})
        return

    # A client may emit the event without a payload.
    if not isinstance(data, dict):
        data = {}

    text = (data.get("text") or "").strip()
    if not text:
        emit("solving_error", {"error": "Problem text is empty"})
        return

    img_b64 = data.get("image")
    img = None
    if isinstance(img_b64, str) and img_b64.startswith("data:image"):
        try:
            # Strip the "data:image/...;base64," prefix before decoding.
            img = base642img(img_b64.split(",", 1)[1])
        except Exception as e:
            emit("solving_error", {"error": f"Bad image: {e}"})
            return

    settings = sessions[sid]["settings"]
    generator_model_id = settings.get("generator_model", settings["model"])
    critic_model_id = settings.get("critic_model", settings["model"])
    pips_mode = settings.get("pips_mode", "AGENT")
    custom_rules = _combine_custom_rules(sid, settings)

    generator_api_key = _api_key_for_model(settings, generator_model_id)
    critic_api_key = _api_key_for_model(settings, critic_model_id)
    if not generator_api_key:
        emit("solving_error", {"error": f"API key missing for generator model: {generator_model_id}"})
        return

    # A previous task for this client may still be running (e.g. a
    # double-submit). Stop it rather than orphan it: once its active_tasks
    # entry is replaced, nothing could interrupt it any more.
    previous = active_tasks.get(sid)
    if previous is not None:
        previous["event"].set()

    stop_evt = threading.Event()

    def task():
        try:
            print(f"[DEBUG] Starting solving task for session {sid}")
            sample = RawInput(text_input=text, image_input=img)
            generator_model = get_model(generator_model_id, generator_api_key)
            cbs = make_callbacks(
                sid, generator_model_id, critic_model_id, stop_evt, settings["max_execution_time"]
            )

            print(f"[DEBUG] Emitting solving_started for session {sid}")
            socketio.emit("solving_started", {}, room=sid)
            socketio.sleep(0)

            # Fall back to the generator as critic when the critic's
            # provider key is missing.
            critic_model = generator_model
            if critic_model_id != generator_model_id:
                if critic_api_key:
                    critic_model = get_model(critic_model_id, critic_api_key)
                else:
                    print(f"[DEBUG] Critic API key missing for {critic_model_id}; falling back to generator model for criticism.")

            requested_interactive = pips_mode == "INTERACTIVE"
            solver = PIPSSolver(
                generator_model,
                max_iterations=settings["max_iterations"],
                temperature=settings["temperature"],
                max_tokens=settings["max_tokens"],
                interactive=requested_interactive,
                critic_model=critic_model,
            )

            decision_max_tokens = min(1024, settings["max_tokens"])
            answer, logs, mode_decision_summary = solver.solve(
                sample,
                stream=True,
                callbacks=cbs,
                additional_rules=custom_rules,
                decision_max_tokens=decision_max_tokens,
                interactive_requested=requested_interactive,
            )

            use_code = mode_decision_summary.get("use_code", False)
            if sid in sessions:
                sessions[sid]["mode_decision"] = mode_decision_summary
            print(
                f"[DEBUG] Mode decision for session {sid}: "
                f"use_code={use_code}, requested_interactive={requested_interactive}"
            )

            if use_code and critic_model_id != generator_model_id and not critic_api_key:
                cbs["on_step_update"](
                    "mode_selection",
                    "Proceeding without a dedicated critic model because no API key was provided.",
                    iteration=None,
                )

            if use_code:
                print(f"[DEBUG] Used iterative code path for session {sid}")
                # Interactive mode paused at a checkpoint: park the solver
                # until provide_feedback / terminate_session resumes it.
                if requested_interactive and not answer and solver._checkpoint:
                    if sid in sessions:
                        sessions[sid]["solver"] = solver
                    print(f"[DEBUG] Interactive mode - waiting for user feedback for session {sid}")
                    return
            else:
                print(f"[DEBUG] Used chain-of-thought path for session {sid}")

            if stop_evt.is_set():
                print(f"[DEBUG] Task was interrupted for session {sid}")
                socketio.emit("solving_interrupted", {"message": "Interrupted"}, room=sid)
                return

            print(f"[DEBUG] Solving completed, emitting final answer for session {sid}")

            if not isinstance(logs, dict):
                logs = {}
            logs.setdefault("mode_decision", mode_decision_summary)

            # The last recorded symbols/program are the final artifacts.
            latest_symbols = logs.get("all_symbols", [])[-1] if logs.get("all_symbols") else {}
            latest_code = logs.get("all_programs", [])[-1] if logs.get("all_programs") else ""

            socketio.emit("final_artifacts", {
                "symbols": _safe(latest_symbols),
                "code": latest_code,
            }, room=sid)

            socketio.emit(
                "solving_complete",
                {
                    "final_answer": answer,
                    "logs": _safe(logs),
                    "method": "iterative_code" if use_code else "chain_of_thought",
                },
                room=sid,
            )
            if sid in sessions:
                sessions[sid].pop("mode_decision", None)

        except Exception as exc:
            print(f"[DEBUG] Exception in solving task for session {sid}: {exc}")
            if sid in sessions:
                sessions[sid].pop("mode_decision", None)
            socketio.emit("solving_error", {"error": str(exc)}, room=sid)
        finally:
            print(f"[DEBUG] Cleaning up task for session {sid}")
            # Remove only our own registration; a newer solve may have
            # replaced it, and that task must stay interruptible.
            current = active_tasks.get(sid)
            if current is not None and current.get("event") is stop_evt:
                active_tasks.pop(sid, None)

    # Register before starting: a task that finishes immediately (threading
    # mode) would otherwise run its cleanup first and leave a stale entry.
    entry = dict(event=stop_evt)
    active_tasks[sid] = entry
    entry["task"] = socketio.start_background_task(task)
|
|
|
|
|
|
|
|
@socketio.on("interrupt_solving")
def on_interrupt(data=None):
    """Set the stop event of this client's running task and acknowledge it."""
    sid = request.sid
    running = active_tasks.get(sid)
    if running is None:
        emit("solving_interrupted", {"message": "No active task."})
        return
    running["event"].set()
    emit("solving_interrupted", {"message": "Stopped."})
|
|
|
|
|
|
|
|
@socketio.on("provide_feedback")
def on_provide_feedback(data=None):
    """Resume a paused interactive solve with the user's review of the critic.

    ``data`` may contain ``accept_critic`` (default True), ``extra_comments``,
    ``quoted_ranges`` and ``terminate``. The parked solver is claimed
    (removed from the session) before resuming, so a duplicate submission
    cannot resume it twice.

    If the solver pauses again at a new checkpoint, the same condition
    on_solve_problem checks, it is parked again and no ``solving_complete``
    is sent. Otherwise the final artifacts and ``solving_complete`` are
    emitted.
    """
    sid = request.sid
    if sid not in sessions:
        emit("solving_error", {"error": "Session vanished"})
        return

    solver = sessions[sid].get("solver")
    if not solver or not solver._checkpoint:
        emit("solving_error", {"error": "No interactive session waiting for feedback"})
        return

    # A client may emit the event without a payload.
    if not isinstance(data, dict):
        data = {}

    user_feedback = {
        "accept_critic": data.get("accept_critic", True),
        "extra_comments": data.get("extra_comments", ""),
        "quoted_ranges": data.get("quoted_ranges", []),
        "terminate": data.get("terminate", False),
    }

    # Claim the parked solver so a second submission cannot resume it too.
    sessions[sid].pop("solver", None)
    stop_evt = threading.Event()

    def continue_task():
        try:
            print(f"[DEBUG] Continuing interactive task with user feedback for session {sid}")
            answer, logs = solver.continue_from_checkpoint(user_feedback)

            # Paused again for the next round of feedback: park the solver
            # again instead of reporting an empty "complete" result.
            if not answer and solver._checkpoint and not user_feedback["terminate"]:
                if sid in sessions:
                    sessions[sid]["solver"] = solver
                print(f"[DEBUG] Interactive mode - waiting for user feedback for session {sid}")
                return

            # The session may have disconnected while the solver ran.
            session = sessions.get(sid) or {}
            mode_decision = session.get("mode_decision") or getattr(solver, "_mode_decision_summary", None)
            if not isinstance(logs, dict):
                logs = {}
            if mode_decision:
                logs.setdefault("mode_decision", mode_decision)

            # The last recorded symbols/program are the final artifacts.
            latest_symbols = logs.get("all_symbols", [])[-1] if logs.get("all_symbols") else {}
            latest_code = logs.get("all_programs", [])[-1] if logs.get("all_programs") else ""

            socketio.emit("final_artifacts", {
                "symbols": _safe(latest_symbols),
                "code": latest_code,
            }, room=sid)

            socketio.emit("solving_complete", {
                "final_answer": answer,
                "logs": _safe(logs),
                "method": "iterative_code_interactive",
            }, room=sid)
            if sid in sessions:
                sessions[sid].pop("mode_decision", None)

        except Exception as exc:
            print(f"[DEBUG] Exception in interactive continuation for session {sid}: {exc}")
            socketio.emit("solving_error", {"error": str(exc)}, room=sid)
            if sid in sessions:
                sessions[sid].pop("mode_decision", None)
        finally:
            # Remove only our own registration; a newer task may own the slot.
            current = active_tasks.get(sid)
            if current is not None and current.get("event") is stop_evt:
                active_tasks.pop(sid, None)

    # Register before starting so a fast-finishing task cannot leave a stale
    # entry behind. NOTE(review): nothing inside the continuation reads this
    # event, so interrupt_solving cannot stop it.
    entry = dict(event=stop_evt)
    active_tasks[sid] = entry
    entry["task"] = socketio.start_background_task(continue_task)
|
|
|
|
|
|
|
|
@socketio.on("terminate_session")
def on_terminate_session(data=None):
    """Finish a paused interactive solve at the user's request.

    The parked solver is claimed (removed from the session), then resumed
    with ``{"terminate": True}``. Whatever it returns is emitted as
    ``final_artifacts`` + ``solving_complete`` with method
    ``iterative_code_interactive_terminated``. ``data`` is ignored.
    """
    sid = request.sid
    if sid not in sessions:
        emit("solving_error", {"error": "Session vanished"})
        return

    solver = sessions[sid].get("solver")
    if not solver or not solver._checkpoint:
        emit("solving_error", {"error": "No interactive session to terminate"})
        return

    user_feedback = {"terminate": True}

    # Claim the parked solver so a repeated request cannot resume it twice.
    sessions[sid].pop("solver", None)
    stop_evt = threading.Event()

    def terminate_task():
        try:
            print(f"[DEBUG] Terminating interactive task for session {sid}")
            answer, logs = solver.continue_from_checkpoint(user_feedback)

            # The session may have disconnected while the solver ran.
            session = sessions.get(sid) or {}
            mode_decision = session.get("mode_decision") or getattr(solver, "_mode_decision_summary", None)
            if not isinstance(logs, dict):
                logs = {}
            if mode_decision:
                logs.setdefault("mode_decision", mode_decision)

            # The last recorded symbols/program are the final artifacts.
            latest_symbols = logs.get("all_symbols", [])[-1] if logs.get("all_symbols") else {}
            latest_code = logs.get("all_programs", [])[-1] if logs.get("all_programs") else ""

            socketio.emit("final_artifacts", {
                "symbols": _safe(latest_symbols),
                "code": latest_code,
            }, room=sid)

            socketio.emit("solving_complete", {
                "final_answer": answer,
                "logs": _safe(logs),
                "method": "iterative_code_interactive_terminated",
            }, room=sid)
            if sid in sessions:
                sessions[sid].pop("mode_decision", None)

        except Exception as exc:
            print(f"[DEBUG] Exception in interactive termination for session {sid}: {exc}")
            socketio.emit("solving_error", {"error": str(exc)}, room=sid)
            if sid in sessions:
                sessions[sid].pop("mode_decision", None)
        finally:
            # Remove only our own registration; a newer task may own the slot.
            current = active_tasks.get(sid)
            if current is not None and current.get("event") is stop_evt:
                active_tasks.pop(sid, None)

    # Register before starting so a fast-finishing task cannot leave a stale
    # entry behind.
    entry = dict(event=stop_evt)
    active_tasks[sid] = entry
    entry["task"] = socketio.start_background_task(terminate_task)
|
|
|
|
|
|
|
|
@socketio.on("switch_mode")
def on_switch_mode(data=None):
    """Set the session's ``pips_mode`` to "AGENT" or "INTERACTIVE".

    ``data["mode"]`` defaults to "AGENT". A missing or non-object payload is
    treated as ``{}`` instead of crashing the handler. Any other mode value
    is rejected with ``solving_error``. On success ``mode_switched`` echoes
    the new mode.
    """
    sid = request.sid
    if sid not in sessions:
        emit("solving_error", {"error": "Session vanished"})
        return

    if not isinstance(data, dict):
        data = {}

    new_mode = data.get("mode", "AGENT")
    if new_mode not in ["AGENT", "INTERACTIVE"]:
        emit("solving_error", {"error": "Invalid mode"})
        return

    sessions[sid]["settings"]["pips_mode"] = new_mode
    emit("mode_switched", {"mode": new_mode})
|
|
|
|
|
|
|
|
@socketio.on("heartbeat")
def on_heartbeat(data=None):
    """Echo the client's heartbeat timestamp together with the server clock.

    ``data`` is optional. A client that emits ``heartbeat`` without a
    payload, or with a non-object payload, gets ``"timestamp": None`` back
    instead of a TypeError/AttributeError in the handler.
    """
    client_timestamp = data.get("timestamp") if isinstance(data, dict) else None
    emit("heartbeat_response", {"timestamp": client_timestamp, "server_time": time.time()})
|
|
|
|
|
|
|
|
@socketio.on("download_chat_log")
def on_download_chat_log(data=None):
    """Send the client a JSON export of its session settings and chat history.

    Replies with ``chat_log_ready`` (``filename`` + pretty-printed JSON
    ``content``), or ``error`` if the session is gone.

    Any non-empty ``*_api_key`` setting is replaced with "<redacted>". The
    exported file is meant to be saved and shared, and must not carry the
    user's provider credentials. ``data`` is accepted and ignored, so a client
    that emits the event with a payload does not trigger a TypeError.
    """
    sid = request.sid
    sess = sessions.get(sid)
    if not sess:
        emit("error", {"message": "Session missing"})
        return

    exported_settings = {
        key: ("<redacted>" if isinstance(key, str) and key.endswith("_api_key") and value else value)
        for key, value in sess["settings"].items()
    }
    payload = dict(
        session_id=sid,
        timestamp=datetime.utcnow().isoformat(),
        settings=_safe(exported_settings),
        chat_history=_safe(sess["chat"]),
    )
    emit(
        "chat_log_ready",
        {
            "filename": f"pips_chat_{sid[:8]}.json",
            "content": json.dumps(payload, indent=2),
        },
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_app(host: str = "0.0.0.0", port: int = 8080, debug: bool = False):
    """Ensure the ``uploads`` directory exists, then serve the app via Socket.IO.

    Blocks until the server stops.
    """
    uploads_dir = "uploads"
    os.makedirs(uploads_dir, exist_ok=True)

    server_options = {"host": host, "port": port, "debug": debug}
    socketio.run(app, **server_options)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: --host, --port and --debug map onto run_app().
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    parser.add_argument("--debug", action="store_true")
    cli = parser.parse_args()

    run_app(host=cli.host, port=cli.port, debug=cli.debug)
|
|
|