"""SSE endpoints that run the document-generator agent in the background.

A generation run for a project is started as a detached asyncio task
(``run_generator_task``). Any number of SSE clients can subscribe to its
events through per-client queues registered in ``SUBSCRIBERS``.
"""
import asyncio
import json
import logging
import os
import sys
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple

from fastapi import APIRouter, Query
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

logger = logging.getLogger(__name__)
router = APIRouter()

# Background generator task per project_id; removed when the task finishes.
RUNNING_TASKS: Dict[str, asyncio.Task[Any]] = {}
# SSE subscriber queues per project_id; each queue receives every broadcast
# event, and ``None`` marks the end of the stream.
SUBSCRIBERS: Dict[str, List[asyncio.Queue[Any]]] = defaultdict(list)


def _section_title(section: Any) -> str:
    """Return the title of a plan entry (a dict with "title" or a plain string)."""
    return section["title"] if isinstance(section, dict) else section


def _section_type(section: Any) -> str:
    """Return the section_type of a plan entry; plain-string entries get a slug."""
    if isinstance(section, dict):
        return section["type"]
    return _section_title(section).lower().replace(" ", "_")


async def _broadcast(project_id: str, msg: Optional[Dict[str, Any]]) -> None:
    """Push ``msg`` to every SSE queue subscribed to ``project_id``.

    Delivery is best-effort: a queue that fails is skipped (logged at DEBUG).
    ``None`` is the end-of-stream sentinel understood by ``event_publisher``.
    """
    # Iterate over a snapshot: a disconnecting client removes its queue.
    for q in list(SUBSCRIBERS.get(project_id, ())):
        try:
            await q.put(msg)
        except Exception:
            logger.debug("Broadcast to a subscriber failed", exc_info=True)


def _load_project_context(
    project_id: str,
) -> Tuple[str, List[Dict[str, str]], str]:
    """Read the project description, section plan and company data from the DB.

    Best-effort and blocking (call it via ``asyncio.to_thread`` from async
    code). On any error the failure is logged at DEBUG level and whatever was
    collected so far is returned.

    Returns:
        ``(project_description, sections_plan, company_context)`` where
        ``sections_plan`` is a list of ``{"type": ..., "title": ...}`` dicts
        and ``company_context`` is the JSON dump of
        ``external_context["company_data"]`` (or ``""``).
    """
    project_description = ""
    db_plan: List[Dict[str, str]] = []
    company_context = ""
    db = None
    try:
        from core.subscription.db import SessionLocal
        from core.projects.models import (
            Project,
            ProjectSection,
            ProjectSectionTemplate,
        )

        db = SessionLocal()
        proj = db.query(Project).filter(Project.id == project_id).first()
        if proj is None:
            return project_description, db_plan, company_context

        external = proj.external_context
        parts: List[str] = []
        description = getattr(proj, "description", None)
        if description:
            parts.append(description)
        nip = getattr(proj, "nip", None)
        if nip:
            parts.append(f"NIP: {nip}")

        # Company name: dedicated attribute first, then the external company data.
        c_name = getattr(proj, "company_name", None)
        if not c_name and external and "company_data" in external:
            c_name = external["company_data"].get("name")
        if c_name:
            parts.append(
                f"DANE OBOWIĄZKOWE: Wnioskodawcą jest firma: {c_name}. Bezwzględnie używaj tej nazwy firmy (lub jej zanonimizowanego tokenu) pisząc sekcje wniosku, aby uniknąć bezosobowego tonu."
            )
        if external:
            parts.append(
                "DODATKOWY KONTEKST ZEWNĘTRZNY (np. dane GUS, wpisane przez użytkownika informacje):"
            )
            parts.append(json.dumps(external, ensure_ascii=False, indent=2))
        project_description = "\n".join(parts)

        db_sections = (
            db.query(ProjectSection)
            .filter(ProjectSection.project_id == project_id)
            .order_by(ProjectSection.order)
            .all()
        )
        if db_sections:
            templates = (
                db.query(ProjectSectionTemplate)
                .filter(ProjectSectionTemplate.program_type == proj.program_type)
                .all()
            )
            from endpoints.projects import UNIVERSAL_FALLBACK_MAP

            template_titles = {t.section_type: t.title for t in templates}
            for sec in db_sections:
                # Program-specific template title, else the universal fallback,
                # else the raw section_type.
                title = template_titles.get(
                    sec.section_type
                ) or UNIVERSAL_FALLBACK_MAP.get(sec.section_type, sec.section_type)
                db_plan.append({"type": sec.section_type, "title": title})

        if external and "company_data" in external:
            company_context = json.dumps(
                external["company_data"], ensure_ascii=False, indent=2
            )
    except Exception as desc_err:
        logger.debug(
            "[Generator] Nie udało się wczytać opisu projektu: %s", desc_err
        )
    finally:
        if db is not None:
            db.close()
    return project_description, db_plan, company_context


def _save_section_content(project_id: str, section_type: str, content: str) -> None:
    """Persist a freshly drafted section as unapproved AI output (best-effort, blocking)."""
    db = None
    try:
        from core.subscription.db import SessionLocal
        from core.projects.models import ProjectSection

        db = SessionLocal()
        db_sec = (
            db.query(ProjectSection)
            .filter(
                ProjectSection.project_id == project_id,
                ProjectSection.section_type == section_type,
            )
            .first()
        )
        if db_sec:
            db_sec.content = content
            db_sec.generated_by_ai = True
            db_sec.is_approved = False
            db.commit()
    except Exception as e:
        logger.error("[Generator] partial save failed: %s", e)
    finally:
        if db is not None:
            db.close()


def _save_final_document(project_id: str, final_md: str) -> None:
    """Store the compiled Markdown document on the project (best-effort, blocking)."""
    db = None
    try:
        from core.subscription.db import SessionLocal
        from core.projects.models import Project

        db = SessionLocal()
        project = db.query(Project).filter(Project.id == project_id).first()
        if project:
            now = datetime.now(timezone.utc)
            project.final_document_markdown = final_md
            project.final_document_generated_at = now
            project.updated_at = now
            db.commit()
    except Exception as db_err:
        logger.warning(
            "Zapis final_document do DB nieudany (niekrytyczny): %s", db_err
        )
    finally:
        if db is not None:
            db.close()


def _event_output(event: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``event["data"]["output"]``, treating missing/None levels as ``{}``."""
    return (event.get("data") or {}).get("output") or {}


async def run_generator_task(
    project_id: str, namespace: str, document_type: str, resume: bool = False
):
    """Run the generator agent for one project and broadcast its progress.

    Events sent to subscribers: ``section_started``, ``section_completed``,
    ``waiting_for_user_input`` (the run stops without building the document),
    ``document_done`` or ``error``. A final ``None`` sentinel is always sent
    and the task is removed from ``RUNNING_TASKS`` when it finishes.
    Blocking DB work is off-loaded with ``asyncio.to_thread`` so it does not
    stall the event loop shared by all requests.
    """
    logger.info("Rozpoczynam tlo generatora dla projektu %s", project_id)
    try:
        from agents.generator_agent import DocumentGeneratorAgent, GeneratorState
        from core.document_builder import DocumentBuilder

        agent = DocumentGeneratorAgent()
        project_description, db_plan, company_context = await asyncio.to_thread(
            _load_project_context, project_id
        )

        initial_state: GeneratorState = {
            "project_id": project_id,
            "namespace": namespace,
            "document_type": document_type,
            "project_description": project_description,
            "sections_plan": db_plan,
            "current_section_idx": 0,
            "generated_sections": {},
            "context": "",
            "is_completed": False,
            "additional_context": f"BEZWZGLĘDNE ŹRÓDŁO PRAWDY O FIRMIE (DANE Z GUS/KRS):\n{company_context}\nZakaz wymyślania innych danych o firmie!"
            if company_context
            else "",
            "missing_data_question": "",
            "traceability_data": {},
        }
        # Local mirror of the graph state, updated from node outputs below.
        last_state: Dict[str, Any] = dict(initial_state)

        async for event in agent.astm_stream(
            initial_state, thread_id=project_id, resume=resume
        ):
            kind = event.get("event", "")
            node = event.get("name")

            if kind == "on_chain_start" and node == "draft_section":
                # Notify that drafting of the next section has started.
                plan = last_state.get("sections_plan", [])
                idx = last_state.get("current_section_idx", 0)
                if plan and idx < len(plan):
                    await _broadcast(
                        project_id,
                        {"event": "section_started", "data": _section_title(plan[idx])},
                    )

            elif kind == "on_chain_end" and node == "draft_section":
                # A section was finished: save it and notify subscribers.
                output = _event_output(event)
                if "generated_sections" in output:
                    last_state.update(output)
                    # The node has already advanced the index past the drafted section.
                    completed_idx = output.get("current_section_idx", 1) - 1
                    plan = last_state.get("sections_plan", [])
                    if 0 <= completed_idx < len(plan):
                        section = plan[completed_idx]
                        section_title = _section_title(section)
                        section_content = output["generated_sections"].get(
                            section_title, ""
                        )
                        # Partial save so progress survives an interrupted run.
                        await asyncio.to_thread(
                            _save_section_content,
                            project_id,
                            _section_type(section),
                            section_content,
                        )
                        await _broadcast(
                            project_id,
                            {
                                "event": "section_completed",
                                "data": json.dumps(
                                    {
                                        "title": section_title,
                                        "content": section_content,
                                        "index": completed_idx + 1,
                                    }
                                ),
                            },
                        )

            elif kind == "on_chain_end" and node == "plan_document":
                # Track the (re)computed section plan.
                last_state.update(_event_output(event))

        # Human-in-the-loop: the graph paused waiting for user-provided data.
        if last_state.get("missing_data_question"):
            logger.info(
                "Graf zatrzymany. Oczekiwanie na dane dla projektu %s.", project_id
            )
            await _broadcast(
                project_id,
                {
                    "event": "waiting_for_user_input",
                    "data": json.dumps(
                        {
                            "status": "WAITING_FOR_USER_INPUT",
                            "missing_data_question": last_state[
                                "missing_data_question"
                            ],
                        }
                    ),
                },
            )
            # The graph is paused, so the final document is not compiled.
            return

        # The whole graph has run: build and store the final document.
        final_md = DocumentBuilder.build_markdown(
            sections_plan=last_state.get("sections_plan", []),
            generated_sections=last_state.get("generated_sections", {}),
            document_type=last_state.get("document_type", document_type),
            traceability_data=last_state.get("traceability_data", {}),
        )
        await asyncio.to_thread(_save_final_document, project_id, final_md)
        await _broadcast(
            project_id,
            {"event": "document_done", "data": json.dumps({"full_content": final_md})},
        )
    except asyncio.CancelledError:
        logger.warning("Agent Task %s zostal wymuszony anulowaniem.", project_id)
        # Re-raise so the task is reported as cancelled (asyncio convention).
        raise
    except Exception as e:
        error_msg = str(e)
        if "list index out of range" in error_msg.lower():
            error_msg = "Wystąpił błąd synchronizacji planu sekcji. Spróbuj wygenerować dokument ponownie lub zresetuj projekt."
        logger.error("Błąd strumienia generatora: %s", e, exc_info=True)
        await _broadcast(
            project_id, {"event": "error", "data": json.dumps({"detail": error_msg})}
        )
    finally:
        try:
            # End-of-stream sentinel for every subscriber.
            await _broadcast(project_id, None)
        finally:
            RUNNING_TASKS.pop(project_id, None)


def _resolve_user_id(token: Optional[str]) -> str:
    """Return the ``sub`` claim of ``token``, or ``"anonymous"`` if unavailable.

    SECURITY(review): the JWT signature is NOT verified and the literal
    ``"dev_test_token"`` is always accepted, so a caller can pick an arbitrary
    user id and therefore tenant namespace. Verify the token with the
    application's signing key before relying on this for tenant isolation.
    """
    if not token:
        return "anonymous"
    try:
        import jwt

        if token == "dev_test_token":
            return "test_dev_user"
        decoded = jwt.decode(token, options={"verify_signature": False})
        return decoded.get("sub", "anonymous")
    except Exception:
        logger.debug("Token could not be decoded; using anonymous user", exc_info=True)
        return "anonymous"


@router.get("/stream")
async def generate_document_stream(
    project_id: str,
    document_type: str = "Wniosek FENG",
    resume: bool = False,
    token: Optional[str] = Query(default=None, alias="token"),
):
    """SSE (Server-Sent Events) stream that starts the agent and subscribes to it.

    The agent runs in a task detached via ``asyncio.create_task``, so closing
    the browser tab does not stop generation. If a run for ``project_id`` is
    already in progress, this request only subscribes to it (``document_type``
    and ``resume`` are then ignored).
    """
    namespace = f"tenant_{_resolve_user_id(token)}"

    if project_id not in RUNNING_TASKS:
        RUNNING_TASKS[project_id] = asyncio.create_task(
            run_generator_task(project_id, namespace, document_type, resume)
        )

    queue: asyncio.Queue[Any] = asyncio.Queue()
    SUBSCRIBERS[project_id].append(queue)

    async def event_publisher():
        try:
            while True:
                msg = await queue.get()
                if msg is None:  # End-of-stream sentinel from run_generator_task.
                    break
                yield msg
        except asyncio.CancelledError:
            logger.info("Klient SSE odłączył się, tło agenta nadal działa.")
            raise
        finally:
            # Unsubscribe; drop the project's entry once nobody listens.
            subscribers = SUBSCRIBERS.get(project_id)
            if subscribers is not None and queue in subscribers:
                subscribers.remove(queue)
                if not subscribers:
                    SUBSCRIBERS.pop(project_id, None)

    return EventSourceResponse(event_publisher())


class ResumeRequest(BaseModel):
    """Body of ``POST /resume``: the user's answer to a paused run's question."""

    project_id: str
    user_response: str


@router.post("/resume")
async def resume_generation(req: ResumeRequest):
    """Accept the user's (human-in-the-loop) answer and update the graph state."""
    from agents.generator_agent import DocumentGeneratorAgent

    agent = DocumentGeneratorAgent()
    agent.provide_human_response(req.project_id, req.user_response)
    return {"status": "ok", "message": "Zaktualizowano stan."}