Spaces:
Running
Running
| from typing import List | |
| from fastapi import APIRouter, HTTPException | |
| from loguru import logger | |
| from api.models import ( | |
| DefaultPromptResponse, | |
| DefaultPromptUpdate, | |
| TransformationCreate, | |
| TransformationExecuteRequest, | |
| TransformationExecuteResponse, | |
| TransformationResponse, | |
| TransformationUpdate, | |
| ) | |
| from open_notebook.domain.models import Model | |
| from open_notebook.domain.transformation import DefaultPrompts, Transformation | |
| from open_notebook.exceptions import InvalidInputError | |
| from open_notebook.graphs.transformation import graph as transformation_graph | |
| router = APIRouter() | |
async def get_transformations():
    """List all transformations.

    Records are requested from ``Transformation.get_all`` with
    ``order_by="name asc"`` and each one is mapped to a
    ``TransformationResponse``. A falsy ``id`` is sent as ``""`` and the
    ``created`` / ``updated`` values are converted with ``str()``.

    Raises:
        HTTPException: status 500, carrying the error text, when anything
            in the lookup or conversion fails.
    """
    try:
        records = await Transformation.get_all(order_by="name asc")
        return [
            TransformationResponse(
                id=record.id or "",
                name=record.name,
                title=record.title,
                description=record.description,
                prompt=record.prompt,
                apply_default=record.apply_default,
                created=str(record.created),
                updated=str(record.updated),
            )
            for record in records
        ]
    except Exception as exc:
        # The logged text and the HTTP detail are the same string.
        message = f"Error fetching transformations: {str(exc)}"
        logger.error(message)
        raise HTTPException(status_code=500, detail=message)
async def create_transformation(transformation_data: TransformationCreate):
    """Create and save a transformation from the request payload.

    Builds a ``Transformation`` from the payload's ``name``, ``title``,
    ``description``, ``prompt`` and ``apply_default``, saves it, and returns
    the saved object as a ``TransformationResponse`` (falsy ``id`` becomes
    ``""``; timestamps are stringified).

    Raises:
        HTTPException: status 400 when ``InvalidInputError`` is raised,
            status 500 for any other exception.
    """
    try:
        record = Transformation(
            name=transformation_data.name,
            title=transformation_data.title,
            description=transformation_data.description,
            prompt=transformation_data.prompt,
            apply_default=transformation_data.apply_default,
        )
        await record.save()

        response = TransformationResponse(
            id=record.id or "",
            name=record.name,
            title=record.title,
            description=record.description,
            prompt=record.prompt,
            apply_default=record.apply_default,
            created=str(record.created),
            updated=str(record.updated),
        )
        return response
    except InvalidInputError as exc:
        # Invalid input is reported as a client error and is not logged here.
        raise HTTPException(status_code=400, detail=str(exc))
    except Exception as exc:
        message = f"Error creating transformation: {str(exc)}"
        logger.error(message)
        raise HTTPException(status_code=500, detail=message)
async def execute_transformation(execute_request: TransformationExecuteRequest):
    """Run a stored transformation against the supplied input text.

    Looks up the transformation and then the model named in the request,
    invokes ``transformation_graph`` with the input text and transformation,
    and returns the graph's ``"output"`` entry together with the two ids.

    Raises:
        HTTPException: status 404 when the transformation or the model
            lookup returns a falsy value; status 500 for any other failure.
    """
    try:
        transformation_id = execute_request.transformation_id
        transformation = await Transformation.get(transformation_id)
        if not transformation:
            raise HTTPException(status_code=404, detail="Transformation not found")

        # The model is fetched only to confirm it exists; the graph receives
        # just the id through its config.
        model_id = execute_request.model_id
        if not await Model.get(model_id):
            raise HTTPException(status_code=404, detail="Model not found")

        graph_input = {
            "input_text": execute_request.input_text,
            "transformation": transformation,
        }
        graph_config = {"configurable": {"model_id": model_id}}
        result = await transformation_graph.ainvoke(
            graph_input,  # type: ignore[arg-type]
            config=graph_config,
        )

        return TransformationExecuteResponse(
            output=result["output"],
            transformation_id=transformation_id,
            model_id=model_id,
        )
    except HTTPException:
        # Let the 404s raised above pass through unchanged.
        raise
    except Exception as exc:
        message = f"Error executing transformation: {str(exc)}"
        logger.error(message)
        raise HTTPException(status_code=500, detail=message)
async def get_default_prompt():
    """Return the default transformation instructions.

    Reads ``transformation_instructions`` from ``DefaultPrompts.get_instance()``
    and wraps it in a ``DefaultPromptResponse``; a falsy value is reported
    as ``""``.

    Raises:
        HTTPException: status 500, carrying the error text, on any failure.
    """
    try:
        prompts: DefaultPrompts = await DefaultPrompts.get_instance()  # type: ignore[assignment]
        instructions = prompts.transformation_instructions
        if not instructions:
            instructions = ""
        return DefaultPromptResponse(transformation_instructions=instructions)
    except Exception as exc:
        message = f"Error fetching default prompt: {str(exc)}"
        logger.error(message)
        raise HTTPException(status_code=500, detail=message)
async def update_default_prompt(prompt_update: DefaultPromptUpdate):
    """Update the default transformation prompt.

    Stores ``prompt_update.transformation_instructions`` on the
    ``DefaultPrompts`` instance, calls its ``update()``, and returns the
    stored value in a ``DefaultPromptResponse``.

    Raises:
        HTTPException: status 500, carrying the error text, on any failure.
    """
    try:
        default_prompts: DefaultPrompts = await DefaultPrompts.get_instance()  # type: ignore[assignment]
        default_prompts.transformation_instructions = prompt_update.transformation_instructions
        await default_prompts.update()
        return DefaultPromptResponse(
            # Coalesce a falsy value to "" exactly as get_default_prompt does,
            # so both endpoints return the same shape and a None value cannot
            # fail response construction after the update was already saved.
            transformation_instructions=default_prompts.transformation_instructions or ""
        )
    except Exception as e:
        logger.error(f"Error updating default prompt: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error updating default prompt: {str(e)}"
        )
async def get_transformation(transformation_id: str):
    """Fetch one transformation by its id.

    Returns the record as a ``TransformationResponse`` (falsy ``id`` becomes
    ``""``; ``created`` / ``updated`` are stringified).

    Raises:
        HTTPException: status 404 when ``Transformation.get`` returns a
            falsy value; status 500 for any other failure.
    """
    try:
        record = await Transformation.get(transformation_id)
        if not record:
            raise HTTPException(status_code=404, detail="Transformation not found")

        response = TransformationResponse(
            id=record.id or "",
            name=record.name,
            title=record.title,
            description=record.description,
            prompt=record.prompt,
            apply_default=record.apply_default,
            created=str(record.created),
            updated=str(record.updated),
        )
        return response
    except HTTPException:
        # Keep the 404 from above instead of converting it to a 500.
        raise
    except Exception as exc:
        # The log line names the id; the client-facing detail does not.
        logger.error(f"Error fetching transformation {transformation_id}: {str(exc)}")
        detail = f"Error fetching transformation: {str(exc)}"
        raise HTTPException(status_code=500, detail=detail)
async def update_transformation(
    transformation_id: str, transformation_update: TransformationUpdate
):
    """Apply a partial update to an existing transformation.

    Only fields of ``transformation_update`` that are not ``None`` are
    copied onto the stored record, which is then saved and returned as a
    ``TransformationResponse``.

    Raises:
        HTTPException: status 404 when the record lookup returns a falsy
            value, status 400 when ``InvalidInputError`` is raised, and
            status 500 for any other failure.
    """
    try:
        record = await Transformation.get(transformation_id)
        if not record:
            raise HTTPException(status_code=404, detail="Transformation not found")

        # Copy across only the fields the caller supplied, in a fixed order.
        for field in ("name", "title", "description", "prompt", "apply_default"):
            new_value = getattr(transformation_update, field)
            if new_value is not None:
                setattr(record, field, new_value)

        await record.save()

        response = TransformationResponse(
            id=record.id or "",
            name=record.name,
            title=record.title,
            description=record.description,
            prompt=record.prompt,
            apply_default=record.apply_default,
            created=str(record.created),
            updated=str(record.updated),
        )
        return response
    except HTTPException:
        # Keep the 404 from above instead of converting it to a 500.
        raise
    except InvalidInputError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except Exception as exc:
        # The log line names the id; the client-facing detail does not.
        logger.error(f"Error updating transformation {transformation_id}: {str(exc)}")
        detail = f"Error updating transformation: {str(exc)}"
        raise HTTPException(status_code=500, detail=detail)
async def delete_transformation(transformation_id: str):
    """Delete the transformation with the given id.

    Returns a dict with a single ``"message"`` key confirming the deletion.

    Raises:
        HTTPException: status 404 when ``Transformation.get`` returns a
            falsy value; status 500 for any other failure.
    """
    try:
        record = await Transformation.get(transformation_id)
        if not record:
            raise HTTPException(status_code=404, detail="Transformation not found")

        await record.delete()
        confirmation = {"message": "Transformation deleted successfully"}
        return confirmation
    except HTTPException:
        # Keep the 404 from above instead of converting it to a 500.
        raise
    except Exception as exc:
        # The log line names the id; the client-facing detail does not.
        logger.error(f"Error deleting transformation {transformation_id}: {str(exc)}")
        detail = f"Error deleting transformation: {str(exc)}"
        raise HTTPException(status_code=500, detail=detail)