Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import csv | |
| import io | |
| import re | |
| from datetime import datetime | |
| from typing import Any | |
| from fastapi import APIRouter, File, Form, HTTPException, UploadFile | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from app.services.comparison import run_comparison, run_csv_comparison, stream_comparison | |
| router = APIRouter() | |
class CompareRequest(BaseModel):
    """Request body accepted by the single-query comparison handlers."""

    # Prompt text handed to the comparison service as ``query``.
    query: str
    # Neon selections; the handlers answer HTTP 400 when this list is empty.
    neon_selections: list[dict[str, str]]
    # Ids of additional models to compare against (none by default).
    comparison_model_ids: list[str] = []
    # Optional session identifier, forwarded unchanged.
    session_id: str | None = None
    # Forwarded unchanged to the comparison service.
    persona_target: str = "neon-only"
async def compare(req: CompareRequest):
    """Run one comparison and return its groups under the ``"groups"`` key.

    Raises:
        HTTPException: 400 when ``req.neon_selections`` is empty.
    """
    if not req.neon_selections:
        raise HTTPException(status_code=400, detail="Select at least one Neon model")

    # Every request field is forwarded to the service unchanged.
    call_kwargs = {
        "query": req.query,
        "neon_selections": req.neon_selections,
        "comparison_model_ids": req.comparison_model_ids,
        "session_id": req.session_id,
        "persona_target": req.persona_target,
    }
    comparison_groups = await run_comparison(**call_kwargs)
    return {"groups": comparison_groups}
async def compare_stream(req: CompareRequest):
    """Return the comparison as a ``text/event-stream`` streaming response.

    Raises:
        HTTPException: 400 when ``req.neon_selections`` is empty.
    """
    if not req.neon_selections:
        raise HTTPException(status_code=400, detail="Select at least one Neon model")

    # Every request field is forwarded to the streaming service unchanged.
    event_source = stream_comparison(
        query=req.query,
        neon_selections=req.neon_selections,
        comparison_model_ids=req.comparison_model_ids,
        session_id=req.session_id,
        persona_target=req.persona_target,
    )
    # Ask clients and intermediaries not to cache or buffer the stream
    # (X-Accel-Buffering is the header nginx-style proxies conventionally honor).
    stream_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_source,
        media_type="text/event-stream",
        headers=stream_headers,
    )
| def _strip_markdown(text: str) -> str: | |
| """Remove common markdown formatting from text for CSV output.""" | |
| text = re.sub(r'#{1,6}\s+', '', text) | |
| text = re.sub(r'\*\*\*(.+?)\*\*\*', r'\1', text) | |
| text = re.sub(r'\*\*(.+?)\*\*', r'\1', text) | |
| text = re.sub(r'\*(.+?)\*', r'\1', text) | |
| text = re.sub(r'__(.+?)__', r'\1', text) | |
| text = re.sub(r'_(.+?)_', r'\1', text) | |
| text = re.sub(r'~~(.+?)~~', r'\1', text) | |
| text = re.sub(r'`{1,3}(.+?)`{1,3}', r'\1', text, flags=re.DOTALL) | |
| text = re.sub(r'^\s*[-*+]\s+', '', text, flags=re.MULTILINE) | |
| text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE) | |
| text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text) | |
| return text.strip() | |
async def compare_csv(
    file: UploadFile = File(...),
    neon_selections: str = Form(...),
    comparison_model_ids: str = Form(""),
    persona_target: str = Form("neon-only"),
):
    """Run a comparison for every question in an uploaded CSV and return a CSV.

    The first column of each non-empty row is taken as a question; cells that
    are exactly "question(s)" or "prompt(s)" (case-insensitive) are treated as
    header labels and skipped. The response has one row per question and one
    column per distinct model key, with markdown stripped from each answer.

    Args:
        file: Uploaded CSV file, expected to be UTF-8 (a BOM is tolerated).
        neon_selections: JSON-encoded Neon selections.
        comparison_model_ids: JSON array of model ids, or a comma-separated
            list when the value is not valid JSON.
        persona_target: Forwarded unchanged to the comparison service.

    Raises:
        HTTPException: 400 when ``neon_selections`` is not valid JSON or is
            empty, when the upload is not UTF-8 or not parseable as CSV, or
            when it contains no questions.
    """
    import json
    from urllib.parse import quote

    try:
        neon_sel = json.loads(neon_selections)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid neon_selections JSON") from None
    # Same precondition the JSON endpoints enforce on their request body.
    if not neon_sel:
        raise HTTPException(status_code=400, detail="Select at least one Neon model")

    comp_ids = []
    if comparison_model_ids:
        try:
            comp_ids = json.loads(comparison_model_ids)
        except json.JSONDecodeError:
            # Not JSON: fall back to a plain comma-separated list.
            comp_ids = [s.strip() for s in comparison_model_ids.split(",") if s.strip()]

    content = await file.read()
    try:
        # utf-8-sig drops a leading BOM if the file has one.
        text = content.decode("utf-8-sig")
    except UnicodeDecodeError:
        # A badly encoded upload is a client error, not a server failure.
        raise HTTPException(status_code=400, detail="CSV must be UTF-8 encoded") from None

    header_labels = {"question", "questions", "prompt", "prompts"}
    questions = []
    try:
        for row in csv.reader(io.StringIO(text)):
            if not row:
                continue
            q = row[0].strip()
            if q and q.lower() not in header_labels:
                questions.append(q)
    except csv.Error:
        raise HTTPException(status_code=400, detail="Malformed CSV file") from None
    if not questions:
        raise HTTPException(status_code=400, detail="CSV contains no questions")

    all_results = await run_csv_comparison(
        questions=questions,
        neon_selections=neon_sel,
        comparison_model_ids=comp_ids,
        persona_target=persona_target,
    )

    model_keys = _collect_model_keys(all_results)
    output = io.StringIO()
    # Leading BOM is written as the first character of the generated CSV.
    output.write('\ufeff')
    writer = csv.writer(output)
    writer.writerow(["Question"] + model_keys)
    for question, groups in zip(questions, all_results):
        row_data: dict[str, str] = {}
        for group in groups:
            persona = group.get("neon_persona", "")
            for r in group["responses"]:
                row_data[_response_key(r, persona)] = _strip_markdown(r["response"])
        # Models that produced no answer for this question get an empty cell.
        writer.writerow([question] + [row_data.get(k, "") for k in model_keys])

    original_name = file.filename or "upload"
    # At most 12 characters of the name before its last extension.
    stem = original_name.rsplit(".", 1)[0][:12] or "upload"
    now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"{stem} LLM Comparison {now}.csv"

    # The stem comes from the client, so it may hold quotes, control characters
    # or non-Latin-1 text that would corrupt the header. Send a quoted
    # ASCII-safe fallback plus the exact name in RFC 5987 ``filename*`` form.
    ascii_name = re.sub(r'[^A-Za-z0-9 ._-]', "_", filename)
    disposition = (
        f'attachment; filename="{ascii_name}"; '
        f"filename*=UTF-8''{quote(filename, safe='')}"
    )

    return StreamingResponse(
        iter([output.getvalue().encode("utf-8")]),
        media_type="text/csv; charset=utf-8",
        headers={"Content-Disposition": disposition},
    )
| def _response_key(r: dict[str, Any], neon_persona: str = "") -> str: | |
| if r.get("is_neon"): | |
| return f"Neon / {r['model_name']} ({r.get('persona_name', '')})" | |
| label = f"{r.get('provider_name', r['provider'])} / {r['model_name']}" | |
| if neon_persona: | |
| label += f" - {neon_persona}" | |
| return label | |
| def _collect_model_keys(all_results: list) -> list[str]: | |
| keys: list[str] = [] | |
| seen: set[str] = set() | |
| for groups in all_results: | |
| for group in groups: | |
| persona = group.get("neon_persona", "") | |
| for r in group["responses"]: | |
| k = _response_key(r, persona) | |
| if k not in seen: | |
| keys.append(k) | |
| seen.add(k) | |
| return keys | |