Spaces:
Running
Running
| import sys | |
| import os | |
| import json | |
| import logging | |
| import asyncio | |
| from collections import defaultdict | |
| from fastapi import APIRouter, Query | |
| from pydantic import BaseModel | |
| from sse_starlette.sse import EventSourceResponse | |
from typing import Any, Dict, List

# Presumably makes the parent directory importable so the lazily imported
# project packages (agents, core, endpoints) resolve — confirm with deployment.
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

logger = logging.getLogger(__name__)
router = APIRouter()

# Background generator tasks keyed by project id; the SSE endpoint starts a new
# one only when no live task exists for that project.
RUNNING_TASKS: Dict[str, asyncio.Task[Any]] = {}
# SSE subscriber queues keyed by project id; the generator fans every event out
# to each queue and finishes with a ``None`` end-of-stream sentinel.
SUBSCRIBERS: Dict[str, List[asyncio.Queue[Any]]] = defaultdict(list)
def _plan_entry_title(section: Any) -> Any:
    """Return the display title of a sections_plan entry (dict with "title" or a bare title)."""
    return section["title"] if isinstance(section, dict) else section


def _load_generator_context(project_id: str) -> tuple[str, list[dict[str, str]], str]:
    """Read the project description, section plan and company data from the DB.

    Best effort: any failure is logged and whatever was gathered before it is
    returned (empty values otherwise). The DB session is always closed.
    NOTE(review): runs synchronously on the event loop.

    Returns:
        ``(project_description, sections_plan, company_context)`` where
        ``sections_plan`` is a list of ``{"type": ..., "title": ...}`` dicts and
        ``company_context`` is the JSON dump of
        ``external_context["company_data"]`` or ``""`` when absent.
    """
    project_description = ""
    db_plan: list[dict[str, str]] = []
    company_context = ""
    try:
        from core.subscription.db import SessionLocal
        from core.projects.models import (
            Project,
            ProjectSection,
            ProjectSectionTemplate,
        )

        db = SessionLocal()
        try:
            proj = db.query(Project).filter(Project.id == project_id).first()
            if not proj:
                return project_description, db_plan, company_context

            ext = proj.external_context
            has_company_data = bool(ext) and "company_data" in ext

            parts: list[str] = []
            if getattr(proj, "description", None):
                parts.append(proj.description)
            if getattr(proj, "nip", None):
                parts.append(f"NIP: {proj.nip}")
            # Prefer an explicit company_name attribute; fall back to the
            # company data stored in external_context.
            c_name = getattr(proj, "company_name", None)
            if not c_name and has_company_data:
                company_data = ext["company_data"]
                if isinstance(company_data, dict):
                    c_name = company_data.get("name")
            if c_name:
                parts.append(
                    f"DANE OBOWIĄZKOWE: Wnioskodawcą jest firma: {c_name}. Bezwzględnie używaj tej nazwy firmy (lub jej zanonimizowanego tokenu) pisząc sekcje wniosku, aby uniknąć bezosobowego tonu."
                )
            if ext:
                parts.append("DODATKOWY KONTEKST ZEWNĘTRZNY (np. dane GUS, wpisane przez użytkownika informacje):")
                parts.append(json.dumps(ext, ensure_ascii=False, indent=2))
            project_description = "\n".join(parts)

            # Computed before the section plan so a failure there does not
            # also drop the company data.
            if has_company_data:
                company_context = json.dumps(
                    ext["company_data"], ensure_ascii=False, indent=2
                )

            db_sections = (
                db.query(ProjectSection)
                .filter(ProjectSection.project_id == project_id)
                .order_by(ProjectSection.order)
                .all()
            )
            if db_sections:
                templates = (
                    db.query(ProjectSectionTemplate)
                    .filter(ProjectSectionTemplate.program_type == proj.program_type)
                    .all()
                )
                from endpoints.projects import UNIVERSAL_FALLBACK_MAP

                tmpl_map = {t.section_type: t.title for t in templates}
                for sec in db_sections:
                    # Template title first, then the universal fallback, then the raw type.
                    title = tmpl_map.get(sec.section_type) or UNIVERSAL_FALLBACK_MAP.get(
                        sec.section_type, sec.section_type
                    )
                    db_plan.append({"type": sec.section_type, "title": title})
        finally:
            db.close()
    except Exception as desc_err:
        # Warning (not debug): generation continues, but without this context.
        logger.warning(
            f"[Generator] Nie udało się wczytać opisu projektu: {desc_err}"
        )
    return project_description, db_plan, company_context


def _save_generated_section(project_id: str, sec_type: str, content: str) -> None:
    """Persist one freshly generated section as unapproved AI content (best effort)."""
    try:
        from core.subscription.db import SessionLocal
        from core.projects.models import ProjectSection

        db = SessionLocal()
        try:
            db_sec = (
                db.query(ProjectSection)
                .filter(
                    ProjectSection.project_id == project_id,
                    ProjectSection.section_type == sec_type,
                )
                .first()
            )
            if db_sec:
                db_sec.content = content
                db_sec.generated_by_ai = True
                db_sec.is_approved = False
                db.commit()
        finally:
            db.close()
    except Exception as e:
        logger.error(f"[Generator] partial save failed: {e}")


def _save_final_document(project_id: str, final_md: str) -> None:
    """Store the compiled markdown on the Project row (best effort, non-critical)."""
    try:
        from core.subscription.db import SessionLocal
        from core.projects.models import Project
        from datetime import datetime, timezone

        db = SessionLocal()
        try:
            project = db.query(Project).filter(Project.id == project_id).first()
            if project:
                now = datetime.now(timezone.utc)
                project.final_document_markdown = final_md
                project.final_document_generated_at = now
                project.updated_at = now
                db.commit()
        finally:
            db.close()
    except Exception as db_err:
        logger.warning(
            f"Zapis final_document do DB nieudany (niekrytyczny): {db_err}"
        )


async def _broadcast_to_subscribers(project_id: str, msg: Any) -> None:
    """Put ``msg`` on every SSE queue currently subscribed to ``project_id``."""
    # Iterate a snapshot: subscribers may detach while we are awaiting.
    for queue in list(SUBSCRIBERS.get(project_id, [])):
        try:
            await queue.put(msg)
        except Exception as put_err:
            # Best effort: one broken queue must not stop delivery to the others.
            logger.debug(f"[Generator] broadcast to subscriber failed: {put_err}")


async def run_generator_task(
    project_id: str, namespace: str, document_type: str, resume: bool = False
):
    """Run the document generator agent for a project and fan events out to SSE subscribers.

    Broadcasts ``section_started`` / ``section_completed`` while sections are
    drafted, persists each section as it completes, and finally either
    ``waiting_for_user_input`` (graph paused for human input) or
    ``document_done`` with the compiled markdown. On failure an ``error``
    event is sent. A ``None`` sentinel is always sent last, and the project's
    entry is removed from ``RUNNING_TASKS``.

    Args:
        project_id: Project to generate; also used as the agent ``thread_id``.
        namespace: Tenant namespace stored in the initial generator state.
        document_type: Document type stored in the state and passed to the builder.
        resume: Forwarded to ``agent.astm_stream``.

    Raises:
        asyncio.CancelledError: re-raised after logging if the task is cancelled.
    """
    logger.info(f"Rozpoczynam tlo generatora dla projektu {project_id}")
    try:
        from agents.generator_agent import DocumentGeneratorAgent, GeneratorState
        from core.document_builder import DocumentBuilder

        agent = DocumentGeneratorAgent()
        project_description, db_plan, company_context = _load_generator_context(
            project_id
        )
        initial_state: GeneratorState = {
            "project_id": project_id,
            "namespace": namespace,
            "document_type": document_type,
            "project_description": project_description,
            "sections_plan": db_plan,
            "current_section_idx": 0,
            "generated_sections": {},
            "context": "",
            "is_completed": False,
            "additional_context": f"BEZWZGLĘDNE ŹRÓDŁO PRAWDY O FIRMIE (DANE Z GUS/KRS):\n{company_context}\nZakaz wymyślania innych danych o firmie!" if company_context else "",
            "missing_data_question": "",
            "traceability_data": {},
        }
        # Local mirror of the graph state, updated from node outputs below.
        last_state = dict(initial_state)

        async for event in agent.astm_stream(
            initial_state, thread_id=project_id, resume=resume
        ):
            kind = event.get("event", "")
            name = event.get("name")

            if kind == "on_chain_start" and name == "draft_section":
                # Announce the section that is about to be drafted.
                plan = last_state.get("sections_plan", [])
                idx = last_state.get("current_section_idx", 0)
                if plan and idx < len(plan):
                    await _broadcast_to_subscribers(
                        project_id,
                        {
                            "event": "section_started",
                            "data": _plan_entry_title(plan[idx]),
                        },
                    )

            elif kind == "on_chain_end" and name == "draft_section":
                # A section finished: persist it and announce its content.
                output = (event.get("data") or {}).get("output") or {}
                if isinstance(output, dict) and "generated_sections" in output:
                    last_state.update(output)
                    # The node has already advanced the index past the finished section.
                    completed_idx = output.get("current_section_idx", 1) - 1
                    plan = last_state.get("sections_plan", [])
                    if 0 <= completed_idx < len(plan):
                        section = plan[completed_idx]
                        section_title = _plan_entry_title(section)
                        section_content = output["generated_sections"].get(
                            section_title, ""
                        )
                        sec_type = (
                            section["type"]
                            if isinstance(section, dict)
                            else section_title.lower().replace(" ", "_")
                        )
                        _save_generated_section(project_id, sec_type, section_content)
                        await _broadcast_to_subscribers(
                            project_id,
                            {
                                "event": "section_completed",
                                "data": json.dumps(
                                    {
                                        "title": section_title,
                                        "content": section_content,
                                        "index": completed_idx + 1,
                                    }
                                ),
                            },
                        )

            elif kind == "on_chain_end" and name == "plan_document":
                # Track the (possibly replaced) section plan.
                output = (event.get("data") or {}).get("output") or {}
                if isinstance(output, dict):
                    last_state.update(output)

        # Graph paused for human input (HIL): report the question and stop
        # without compiling a final document.
        if last_state.get("missing_data_question"):
            logger.info(
                f"Graf zatrzymany. Oczekiwanie na dane dla projektu {project_id}."
            )
            await _broadcast_to_subscribers(
                project_id,
                {
                    "event": "waiting_for_user_input",
                    "data": json.dumps(
                        {
                            "status": "WAITING_FOR_USER_INPUT",
                            "missing_data_question": last_state[
                                "missing_data_question"
                            ],
                        }
                    ),
                },
            )
            return

        # Whole graph done: build, persist and publish the final document.
        final_md = DocumentBuilder.build_markdown(
            sections_plan=last_state.get("sections_plan", []),
            generated_sections=last_state.get("generated_sections", {}),
            document_type=last_state.get("document_type", document_type),
            traceability_data=last_state.get("traceability_data", {}),
        )
        _save_final_document(project_id, final_md)
        await _broadcast_to_subscribers(
            project_id,
            {
                "event": "document_done",
                "data": json.dumps({"full_content": final_md}),
            },
        )
    except asyncio.CancelledError:
        logger.warning(f"Agent Task {project_id} zostal wymuszony anulowaniem.")
        # Propagate so the task is correctly marked as cancelled.
        raise
    except Exception as e:
        error_msg = str(e)
        if "list index out of range" in error_msg.lower():
            error_msg = "Wystąpił błąd synchronizacji planu sekcji. Spróbuj wygenerować dokument ponownie lub zresetuj projekt."
        logger.error(f"Błąd strumienia generatora: {e}", exc_info=True)
        await _broadcast_to_subscribers(
            project_id,
            {
                "event": "error",
                "data": json.dumps({"detail": error_msg}),
            },
        )
    finally:
        # End-of-stream sentinel for every subscriber, then free the project slot.
        await _broadcast_to_subscribers(project_id, None)
        RUNNING_TASKS.pop(project_id, None)
async def generate_document_stream(
    project_id: str,
    document_type: str = "Wniosek FENG",
    resume: bool = False,
    token: str = Query(default=None, alias="token"),
):
    """Start (or join) the project's background generator and stream its events over SSE.

    Generation runs in its own asyncio task (``run_generator_task``), so a
    client disconnect only detaches this subscriber; the agent keeps running.

    Args:
        project_id: Project whose document is generated.
        document_type: Passed to the generator only when a new task is started.
        resume: Passed to the generator only when a new task is started.
        token: Optional JWT from the ``token`` query parameter; its ``sub``
            claim selects the tenant namespace (``tenant_<sub>``).

    Returns:
        An ``EventSourceResponse`` yielding the generator's event dicts until
        the ``None`` end-of-stream sentinel arrives.
    """
    user_id = "anonymous"
    if token:
        try:
            import jwt

            # SECURITY(review): the signature is NOT verified, so any client can
            # forge ``sub`` and pick another tenant's namespace, and
            # "dev_test_token" is accepted unconditionally. Verify with the real
            # signing key/algorithm before relying on this for tenant isolation.
            if token == "dev_test_token":
                user_id = "test_dev_user"
            else:
                decoded = jwt.decode(token, options={"verify_signature": False})
                user_id = decoded.get("sub", "anonymous")
        except Exception as token_err:
            logger.debug(f"[Generator] token decode failed, using anonymous: {token_err}")
    namespace = f"tenant_{user_id}"

    # Reuse a live task; replace a missing or finished one. A task cancelled
    # before it ever ran never reaches its own cleanup and would otherwise
    # block the project forever.
    task = RUNNING_TASKS.get(project_id)
    if task is None or task.done():
        RUNNING_TASKS[project_id] = asyncio.create_task(
            run_generator_task(project_id, namespace, document_type, resume)
        )

    # Subscribed before this coroutine awaits anything, so the new task cannot
    # publish an event before this queue is registered.
    queue: asyncio.Queue = asyncio.Queue()
    SUBSCRIBERS[project_id].append(queue)

    async def event_publisher():
        """Yield queued events until the ``None`` sentinel, then unsubscribe."""
        try:
            while True:
                msg = await queue.get()
                if msg is None:  # end-of-stream sentinel
                    break
                yield msg
        except asyncio.CancelledError:
            logger.info("Klient SSE odłączył się, tło agenta nadal działa.")
            # Propagate cancellation to the SSE framework instead of swallowing it.
            raise
        finally:
            subscribers = SUBSCRIBERS.get(project_id)
            if subscribers is not None:
                if queue in subscribers:
                    subscribers.remove(queue)
                # Drop empty lists so SUBSCRIBERS does not grow with every project.
                if not subscribers:
                    SUBSCRIBERS.pop(project_id, None)

    return EventSourceResponse(event_publisher())
class ResumeRequest(BaseModel):
    """Request body carrying the user's answer to a paused (HIL) generator question."""

    project_id: str  # Project whose generator is waiting for input.
    user_response: str  # Answer forwarded to the agent's provide_human_response.
async def resume_generation(req: ResumeRequest):
    """Hand the user's human-in-the-loop (HIL) answer to the document generator agent.

    Creates a fresh ``DocumentGeneratorAgent`` and calls
    ``provide_human_response(project_id, user_response)`` on it. Generation is
    not restarted here.

    NOTE(review): a new agent instance is created per call, so this presumably
    relies on graph state being persisted outside the instance (e.g. a shared
    checkpointer) — confirm.

    Returns:
        A static ``{"status": "ok", ...}`` acknowledgement.
    """
    from agents.generator_agent import DocumentGeneratorAgent as _GeneratorAgent

    _GeneratorAgent().provide_human_response(req.project_id, req.user_response)
    return {"status": "ok", "message": "Zaktualizowano stan."}