Spaces:
Sleeping
Sleeping
| import { GoogleGenAI, GenerateContentResponse } from "@google/genai"; | |
| import { VlaData, TaskSegment, Interaction } from '../types'; | |
| import { GET_OVERALL_GOAL_PROMPT, GET_TASKS_AND_INTERACTIONS_PROMPT } from './prompts'; | |
| // Get API key from environment variables - try both names | |
| const getApiKey = () => { | |
| // In development, try process.env (from Vite) | |
| if (typeof process !== 'undefined' && process.env) { | |
| return process.env.GEMINI_API_KEY || process.env.API_KEY || ''; | |
| } | |
| // In production, try window.__ENV__ (injected at runtime) | |
| if (typeof window !== 'undefined' && (window as any).__ENV__) { | |
| return (window as any).__ENV__.GEMINI_API_KEY || (window as any).__ENV__.API_KEY || ''; | |
| } | |
| // Fallback - prompt user for API key | |
| const storedKey = localStorage.getItem('gemini_api_key'); | |
| if (storedKey) return storedKey; | |
| const userKey = prompt('Please enter your Gemini API key:'); | |
| if (userKey) { | |
| localStorage.setItem('gemini_api_key', userKey); | |
| return userKey; | |
| } | |
| return ''; | |
| }; | |
| const apiKey = getApiKey(); | |
| if (!apiKey) { | |
| throw new Error("API key not found. Please set GEMINI_API_KEY or API_KEY environment variable, or enter it when prompted."); | |
| } | |
| const ai = new GoogleGenAI({ apiKey }); | |
// Number of video frames sent to the model per chunked request.
const CHUNK_SIZE = 15;
// Frames shared between consecutive chunks, so a task spanning a chunk
// boundary is seen by at least two requests (see generateTasksAndInteractions).
const OVERLAP = 5;
| /** | |
| * A wrapper for the Gemini API call that includes retry and fallback logic. | |
| * It first tries the primary model. If it encounters rate-limiting errors, | |
| * it retries, and if that still fails, it falls back to a "lite" configuration. | |
| */ | |
| async function callGeminiWithRetry( | |
| params: Parameters<typeof ai.models.generateContent>[0], | |
| maxRetries: number = 2, // Retries for primary model before fallback | |
| initialDelay: number = 1000 | |
| ): Promise<{ response: GenerateContentResponse; usedFallback: boolean }> { | |
| let lastRateLimitError: any; | |
| for (let i = 0; i < maxRetries; i++) { | |
| try { | |
| const primaryParams = { ...params, model: 'gemini-2.5-flash-preview-04-17' }; | |
| const response = await ai.models.generateContent(primaryParams); | |
| return { response, usedFallback: false }; | |
| } catch (error) { | |
| const isRateLimitError = error?.toString().includes('429') || error?.toString().includes('RESOURCE_EXHAUSTED'); | |
| if (isRateLimitError) { | |
| lastRateLimitError = error; | |
| const delay = initialDelay * Math.pow(2, i); | |
| console.warn(`Primary model failed with rate limit on attempt ${i + 1}. Retrying in ${delay}ms...`); | |
| await new Promise(resolve => setTimeout(resolve, delay)); | |
| } else { | |
| console.error("Gemini API call failed with a non-rate-limit error.", error); | |
| throw error; // Fail fast on other errors | |
| } | |
| } | |
| } | |
| // If loop completes, all primary attempts were rate-limited. | |
| console.warn(`All primary model attempts failed due to rate limits. Switching to fallback 'lite' configuration.`); | |
| try { | |
| const fallbackParams = { | |
| ...params, | |
| model: 'gemini-2.5-flash-preview-04-17', | |
| config: { | |
| ...(params.config || {}), | |
| thinkingConfig: { thinkingBudget: 0 } | |
| } | |
| }; | |
| const response = await ai.models.generateContent(fallbackParams); | |
| console.log("Successfully generated content with fallback configuration."); | |
| return { response, usedFallback: true }; | |
| } catch (fallbackError) { | |
| console.error("Fallback configuration also failed.", fallbackError); | |
| // Throw the fallback error as it's the most recent. | |
| throw fallbackError; | |
| } | |
| } | |
| function parseJsonResponse<T>(response: GenerateContentResponse): T { | |
| let jsonStr = response.text.trim(); | |
| const fenceRegex = /^```(\w*)?\s*\n?(.*?)\n?\s*```$/s; | |
| const match = jsonStr.match(fenceRegex); | |
| if (match && match[2]) { | |
| jsonStr = match[2].trim(); | |
| } | |
| try { | |
| return JSON.parse(jsonStr) as T; | |
| } catch (e) { | |
| console.error("Failed to parse JSON response:", jsonStr); | |
| throw new Error("AI response was not valid JSON."); | |
| } | |
| } | |
| /** | |
| * Generates the overall goal by analyzing a few keyframes. | |
| */ | |
| export async function generateOverallGoal(keyframes: string[]): Promise<{ goal: string, usedFallback: boolean }> { | |
| const imageParts = keyframes.map(frame => ({ | |
| inlineData: { mimeType: 'image/jpeg', data: frame.split(',')[1] }, | |
| })); | |
| const contents = [{ text: GET_OVERALL_GOAL_PROMPT }, ...imageParts]; | |
| try { | |
| const { response, usedFallback } = await callGeminiWithRetry({ | |
| model: 'gemini-2.5-flash-preview-04-17', | |
| contents: { parts: contents }, | |
| config: { responseMimeType: "application/json", temperature: 0.1 } | |
| }); | |
| const parsed = parseJsonResponse<{ overallGoal: string }>(response); | |
| if (!parsed.overallGoal) { | |
| throw new Error("AI response for overall goal is missing the 'overallGoal' field."); | |
| } | |
| return { goal: parsed.overallGoal, usedFallback }; | |
| } catch (error) { | |
| console.error("Error calling Gemini API for overall goal:", error); | |
| throw new Error(`AI model failed to determine overall goal: ${error instanceof Error ? error.message : 'Unknown error'}`); | |
| } | |
| } | |
| function mergeAndDeduplicateTasks(tasks: Omit<TaskSegment, 'id'>[]): TaskSegment[] { | |
| if (tasks.length === 0) return []; | |
| // Sort tasks by their start frame | |
| const sortedTasks = tasks.sort((a, b) => a.startFrame - b.startFrame); | |
| const mergedTasks: Omit<TaskSegment, 'id'>[] = []; | |
| if (sortedTasks.length > 0) { | |
| // Deep copy to avoid mutation issues when merging | |
| mergedTasks.push(JSON.parse(JSON.stringify(sortedTasks[0]))); | |
| } else { | |
| return []; | |
| } | |
| for (let i = 1; i < sortedTasks.length; i++) { | |
| const currentTask = sortedTasks[i]; | |
| const lastMergedTask = mergedTasks[mergedTasks.length - 1]; | |
| const isSimilarDescription = currentTask.description.trim().toLowerCase() === lastMergedTask.description.trim().toLowerCase(); | |
| const isOverlapping = currentTask.startFrame < lastMergedTask.endFrame; | |
| // If descriptions are identical and frames overlap, merge them. | |
| if (isSimilarDescription && isOverlapping) { | |
| lastMergedTask.endFrame = Math.max(lastMergedTask.endFrame, currentTask.endFrame); | |
| // Merge interactions and deduplicate them | |
| if (currentTask.interactions) { | |
| lastMergedTask.interactions = lastMergedTask.interactions || []; | |
| const existingInteractionKeys = new Set( | |
| lastMergedTask.interactions.map(inter => `${inter.frameIndex}-${inter.type}`) | |
| ); | |
| for (const newInteraction of currentTask.interactions) { | |
| const newKey = `${newInteraction.frameIndex}-${newInteraction.type}`; | |
| if (!existingInteractionKeys.has(newKey)) { | |
| lastMergedTask.interactions.push(newInteraction); | |
| existingInteractionKeys.add(newKey); | |
| } | |
| } | |
| // Sort interactions by frame index after merging | |
| lastMergedTask.interactions.sort((a,b) => a.frameIndex - b.frameIndex); | |
| } | |
| } else if (currentTask.startFrame >= lastMergedTask.endFrame) { | |
| // Only add new tasks that start after or at the same time the previous one ends | |
| mergedTasks.push(JSON.parse(JSON.stringify(currentTask))); | |
| } | |
| } | |
| // Re-assign sequential IDs and ensure interactions array exists | |
| return mergedTasks.map((task, index) => ({ | |
| ...task, | |
| id: index + 1, | |
| interactions: task.interactions || [], | |
| })); | |
| } | |
| /** | |
| * Generates task segments and their interactions by analyzing the video in overlapping chunks. | |
| */ | |
| export async function generateTasksAndInteractions(frames: string[]): Promise<{ tasks: TaskSegment[], usedFallback: boolean }> { | |
| const chunks: { frames: string[], startIndex: number }[] = []; | |
| for (let i = 0; i < frames.length; i += CHUNK_SIZE - OVERLAP) { | |
| const chunkFrames = frames.slice(i, i + CHUNK_SIZE); | |
| if (chunkFrames.length > 0) { | |
| chunks.push({ frames: chunkFrames, startIndex: i }); | |
| } | |
| } | |
| try { | |
| let anyChunkUsedFallback = false; | |
| const chunkPromises = chunks.map(async (chunk) => { | |
| const imageParts = chunk.frames.map(frame => ({ | |
| inlineData: { mimeType: 'image/jpeg', data: frame.split(',')[1] }, | |
| })); | |
| const endFrameIndex = chunk.startIndex + chunk.frames.length - 1; | |
| const prompt = GET_TASKS_AND_INTERACTIONS_PROMPT(chunk.startIndex, endFrameIndex); | |
| const contents = [{ text: prompt }, ...imageParts]; | |
| const { response, usedFallback } = await callGeminiWithRetry({ | |
| model: 'gemini-2.5-flash-preview-04-17', | |
| contents: { parts: contents }, | |
| config: { responseMimeType: "application/json", temperature: 0.1 } | |
| }); | |
| if (usedFallback) { | |
| anyChunkUsedFallback = true; | |
| } | |
| // AI returns tasks with interactions included | |
| const parsed = parseJsonResponse<Omit<TaskSegment, 'id'>[]>(response); | |
| if (!Array.isArray(parsed)) { | |
| console.warn("AI response for a chunk was not an array, skipping chunk.", parsed); | |
| return []; | |
| } | |
| return parsed; | |
| }); | |
| const resultsFromAllChunks = await Promise.all(chunkPromises); | |
| const allTasks = resultsFromAllChunks.flat(); | |
| // Sort, merge, and de-duplicate tasks and their interactions | |
| return { | |
| tasks: mergeAndDeduplicateTasks(allTasks), | |
| usedFallback: anyChunkUsedFallback | |
| }; | |
| } catch (error) { | |
| console.error("Error calling Gemini API for task and interaction generation:", error); | |
| throw new Error(`AI model failed during analysis: ${error instanceof Error ? error.message : 'Unknown error'}`); | |
| } | |
| } | |