Spaces:
Sleeping
Sleeping
import React, {
  createContext,
  useContext,
  useEffect,
  useMemo,
  useReducer,
  useRef,
  ReactNode,
} from 'react';
import type {
  GenerationState,
  GenerationStep,
  VideoProvider,
  VeoSegment,
  GeneratedVideo,
} from '@/types';
// Initial state
/**
 * State of a fresh session: no provider chosen, no segments queued, no
 * outstanding remote tasks and no error. The reducer's `RESET` action returns
 * to this object while keeping the currently selected provider.
 */
const initialState: GenerationState = {
  step: 'idle',
  provider: null,
  segments: [],
  currentSegmentIndex: 0,
  generatedVideos: [],
  progress: {
    current: 0,
    total: 0,
    message: '',
  },
  error: null,
  taskId: null,
  retryState: null,
  // Checked against GenerationState by the annotation above; no assertion needed.
  activeTaskIds: [],
  isCancelling: false,
};
// Action types
/**
 * Every action understood by `generationReducer`, discriminated on `type`.
 * Members are grouped by the part of the state they touch.
 */
type GenerationAction =
  // Session / flow control
  | { type: 'SET_PROVIDER'; payload: VideoProvider }
  | { type: 'SET_STEP'; payload: GenerationStep }
  | { type: 'RESET' }
  // Segment queue and produced output
  | { type: 'SET_SEGMENTS'; payload: VeoSegment[] }
  | { type: 'SET_CURRENT_SEGMENT'; payload: number }
  | { type: 'ADD_GENERATED_VIDEO'; payload: GeneratedVideo }
  // Progress, errors and retry bookkeeping
  | { type: 'SET_PROGRESS'; payload: { current?: number; total?: number; message?: string } }
  | { type: 'SET_ERROR'; payload: string | null }
  | { type: 'SET_RETRY_STATE'; payload: { failedSegmentIndex: number; error: string } | null }
  // Remote task tracking and cancellation
  | { type: 'SET_TASK_ID'; payload: string | null }
  | { type: 'ADD_TASK_ID'; payload: string }
  | { type: 'REMOVE_TASK_ID'; payload: string }
  | { type: 'SET_CANCELLING'; payload: boolean };
// Reducer
/**
 * Pure reducer for the generation state machine.
 *
 * No-op updates (adding an already-tracked task ID, removing an untracked one,
 * unknown actions) return the same `state` reference so React can skip the
 * re-render.
 *
 * @param state - current state
 * @param action - action to apply
 * @returns the next state
 */
function generationReducer(state: GenerationState, action: GenerationAction): GenerationState {
  switch (action.type) {
    case 'SET_PROVIDER':
      return { ...state, provider: action.payload };
    case 'SET_STEP':
      return { ...state, step: action.payload };
    case 'SET_SEGMENTS':
      return { ...state, segments: action.payload };
    case 'SET_CURRENT_SEGMENT':
      return { ...state, currentSegmentIndex: action.payload };
    case 'ADD_GENERATED_VIDEO':
      return {
        ...state,
        generatedVideos: [...state.generatedVideos, action.payload],
      };
    case 'SET_PROGRESS': {
      // Merge only the fields actually provided: an explicit `undefined` in the
      // payload must not overwrite the stored value (a plain spread would).
      const { current, total, message } = action.payload;
      return {
        ...state,
        progress: {
          ...state.progress,
          current: current ?? state.progress.current,
          total: total ?? state.progress.total,
          message: message ?? state.progress.message,
        },
      };
    }
    case 'SET_ERROR':
      // A truthy error also moves the machine into the 'error' step; clearing
      // the error (null) leaves the current step untouched.
      return { ...state, error: action.payload, step: action.payload ? 'error' : state.step };
    case 'SET_TASK_ID':
      return { ...state, taskId: action.payload };
    case 'SET_RETRY_STATE':
      return { ...state, retryState: action.payload };
    case 'ADD_TASK_ID':
      // Track each task once so cancellation does not hit the same task twice.
      if (state.activeTaskIds.includes(action.payload)) return state;
      return { ...state, activeTaskIds: [...state.activeTaskIds, action.payload] };
    case 'REMOVE_TASK_ID':
      if (!state.activeTaskIds.includes(action.payload)) return state;
      return { ...state, activeTaskIds: state.activeTaskIds.filter((id) => id !== action.payload) };
    case 'SET_CANCELLING':
      return { ...state, isCancelling: action.payload };
    case 'RESET':
      // Start over but keep the provider the user already selected.
      return { ...initialState, provider: state.provider };
    default:
      return state;
  }
}
// Context
/**
 * Value published to consumers of `GenerationContext`: the raw reducer state
 * and `dispatch`, plus named helpers wrapping the common action sequences.
 */
interface GenerationContextValue {
  // Raw reducer access
  state: GenerationState;
  dispatch: React.Dispatch<GenerationAction>;

  // Flow control
  selectProvider: (provider: VideoProvider) => void;
  setStep: (step: GenerationStep) => void;
  startGeneration: (segments: VeoSegment[]) => void;
  advanceSegment: () => void;
  reset: () => void;

  // Segment data and produced output
  updateSegments: (segments: VeoSegment[]) => void;
  addVideo: (video: GeneratedVideo) => void;

  // Progress, errors and retry bookkeeping
  updateProgress: (message: string, current?: number, total?: number) => void;
  setError: (error: string | null) => void;
  setRetryState: (state: { failedSegmentIndex: number; error: string } | null) => void;

  // Remote task tracking and cancellation
  addTaskId: (taskId: string) => void;
  removeTaskId: (taskId: string) => void;
  /** Resolves once the cancellation attempt has completed. */
  cancelGeneration: () => Promise<void>;
}

/**
 * Defaults to `null` so `useGeneration` can detect use outside of a
 * `GenerationProvider`.
 */
const GenerationContext = createContext<GenerationContextValue | null>(null);
// Provider component
/**
 * Owns the generation reducer and publishes its state, plus helper actions,
 * through `GenerationContext`.
 *
 * The helpers are created once, so their identities are stable across renders
 * and safe to list in effect dependency arrays. They read state through
 * `stateRef`, which tracks the latest committed state, rather than through the
 * render-time `state` closure. This keeps `advanceSegment` and
 * `cancelGeneration` correct when they are called from long-running async code
 * that captured them during an earlier render.
 *
 * The context value is memoized on `state`, so consumers only re-render when
 * the state actually changes.
 */
export function GenerationProvider({ children }: { children: ReactNode }) {
  const [state, dispatch] = useReducer(generationReducer, initialState);

  // Latest committed state, readable from the memoized helpers below.
  const stateRef = useRef(state);
  useEffect(() => {
    stateRef.current = state;
  }, [state]);

  const helpers = useMemo<Omit<GenerationContextValue, 'state' | 'dispatch'>>(
    () => ({
      // Record the chosen provider and move on to the configuration step.
      selectProvider: (provider) => {
        dispatch({ type: 'SET_PROVIDER', payload: provider });
        dispatch({ type: 'SET_STEP', payload: 'configuring' });
      },
      setStep: (step) => {
        dispatch({ type: 'SET_STEP', payload: step });
      },
      // Load the segments, rewind to the first one, reset progress and start generating.
      startGeneration: (segments) => {
        dispatch({ type: 'SET_SEGMENTS', payload: segments });
        dispatch({ type: 'SET_CURRENT_SEGMENT', payload: 0 });
        dispatch({
          type: 'SET_PROGRESS',
          payload: { current: 0, total: segments.length, message: 'Starting generation...' },
        });
        dispatch({ type: 'SET_STEP', payload: 'generating_video' });
      },
      // Move to the next segment and mirror the index into progress.current.
      // NOTE: based on the last *committed* index, so two calls in the same tick
      // (before React re-renders) both target the same next index.
      advanceSegment: () => {
        const nextIndex = stateRef.current.currentSegmentIndex + 1;
        dispatch({ type: 'SET_CURRENT_SEGMENT', payload: nextIndex });
        dispatch({ type: 'SET_PROGRESS', payload: { current: nextIndex } });
      },
      addVideo: (video) => {
        dispatch({ type: 'ADD_GENERATED_VIDEO', payload: video });
      },
      // Always updates the message; current/total only when explicitly provided.
      updateProgress: (message, current, total) => {
        dispatch({
          type: 'SET_PROGRESS',
          payload: {
            message,
            ...(current !== undefined && { current }),
            ...(total !== undefined && { total }),
          },
        });
      },
      setError: (error) => {
        dispatch({ type: 'SET_ERROR', payload: error });
      },
      setRetryState: (retryState) => {
        dispatch({ type: 'SET_RETRY_STATE', payload: retryState });
      },
      updateSegments: (segments) => {
        dispatch({ type: 'SET_SEGMENTS', payload: segments });
      },
      addTaskId: (taskId) => {
        dispatch({ type: 'ADD_TASK_ID', payload: taskId });
      },
      removeTaskId: (taskId) => {
        dispatch({ type: 'REMOVE_TASK_ID', payload: taskId });
      },
      // Best-effort cancellation of every tracked task, then mark the run as
      // cancelled. Per-task failures are logged and do not abort the others.
      cancelGeneration: async () => {
        // Ignore repeat requests (e.g. a double-clicked cancel button) while one is in flight.
        if (stateRef.current.isCancelling) return;
        dispatch({ type: 'SET_CANCELLING', payload: true });
        try {
          const { klingCancel } = await import('@/utils/api');
          // Snapshot after the import resolves so tasks registered meanwhile are included.
          const currentTaskIds = [...stateRef.current.activeTaskIds];
          await Promise.all(
            currentTaskIds.map((taskId) =>
              klingCancel(taskId).catch((err: unknown) => {
                console.warn(`Failed to cancel task ${taskId}:`, err);
              }),
            ),
          );
          // Clear all task IDs
          currentTaskIds.forEach((id) => {
            dispatch({ type: 'REMOVE_TASK_ID', payload: id });
          });
          dispatch({ type: 'SET_TASK_ID', payload: null });
          dispatch({ type: 'SET_ERROR', payload: 'Generation cancelled by user' });
          dispatch({ type: 'SET_STEP', payload: 'error' });
        } catch (error) {
          console.error('Error cancelling generation:', error);
          dispatch({ type: 'SET_ERROR', payload: 'Failed to cancel generation' });
        } finally {
          dispatch({ type: 'SET_CANCELLING', payload: false });
        }
      },
      reset: () => {
        dispatch({ type: 'RESET' });
      },
    }),
    // `dispatch` from useReducer and the ref object are stable for the component's lifetime.
    [],
  );

  const value = useMemo<GenerationContextValue>(
    () => ({ state, dispatch, ...helpers }),
    [state, helpers],
  );

  return (
    <GenerationContext.Provider value={value}>
      {children}
    </GenerationContext.Provider>
  );
}
// Hook
/**
 * Returns the value published by the nearest `GenerationProvider`.
 *
 * @throws Error when called from a component that is not rendered inside a
 *   `GenerationProvider`.
 */
export function useGeneration(): GenerationContextValue {
  const ctx = useContext(GenerationContext);
  if (ctx == null) {
    throw new Error('useGeneration must be used within GenerationProvider');
  }
  return ctx;
}