Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image | |
| import io | |
| import logging | |
| import asyncio | |
| from typing import List | |
| import google.generativeai as genai | |
| from dotenv import load_dotenv | |
| import os | |
# Read variables from a local .env file (python-dotenv) before the rest of the
# module looks anything up in the environment.
load_dotenv()

# Root logging at INFO, plus a module-level logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object served by uvicorn at the bottom of this file.
app = FastAPI(title="Sequential Test Step Generator API")

# Allow browser frontends on any origin to call the API.
# NOTE(review): a wildcard origin together with allow_credentials=True is a
# very permissive combination - confirm this is intended outside of demos.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ModelInference:
    """Generate UI test steps from a screenshot with a Gemini model.

    The instance holds a ``genai.GenerativeModel`` in ``self.model`` and
    offers two entry points: ``predict_next_step_with_history`` (one step)
    and ``generate_step_sequence`` (all steps in a single model call).
    """

    # A generated line only counts as a step when, upper-cased, it starts
    # with one of these action keywords.
    _ACTION_PREFIXES = (
        "CLICK:",
        "TYPE:",
        "SCROLL:",
        "WAIT:",
        "VERIFY:",
        "SELECT:",
        "DRAG:",
    )

    # Output-format lines shared by both single-step prompts.
    _STEP_FORMAT_LINES = (
        'Output format: "ACTION: description [x1, y1, x2, y2]"',
        "Actions: CLICK, TYPE, SCROLL, WAIT, VERIFY, SELECT, DRAG",
        "Coordinates: normalized 0.0-1.0",
    )

    # Rule lines shared by both full-sequence prompts (the "Maximum N steps"
    # rule is appended separately because it depends on ``max_steps``).
    _SEQUENCE_RULE_LINES = (
        "CRITICAL RULES:",
        "- Output ONLY the steps, NO explanations, NO reasoning, NO extra text",
        "- One step per line",
        '- Format: "ACTION: description [x1, y1, x2, y2]"',
        "- Actions: CLICK, TYPE, SCROLL, WAIT, VERIFY, SELECT, DRAG",
        "- Coordinates: normalized 0.0-1.0",
        "- For TYPE actions, describe what to type WITHOUT providing example "
        'values (e.g., "TYPE: Enter username in email field" NOT '
        '"TYPE: test@example.com")',
        "- For CLICK actions, describe what to click (e.g., "
        '"CLICK: Click on the username input field")',
    )

    def __init__(self):
        """Configure the Gemini client from the ``KEY`` env var and build the model."""
        genai.configure(api_key=os.getenv("KEY"))
        self.model = genai.GenerativeModel("gemini-2.5-flash")
        # NOTE(review): nothing in this class uses the device; the value is
        # only reported by the status endpoints - confirm "cuda" is intended.
        self.device = "cuda"
        logger.info("Model loaded successfully!")

    # ------------------------------------------------------------------
    # Pure helpers (no I/O, no logging)
    # ------------------------------------------------------------------

    @staticmethod
    def _split_top_level_commas(text: str) -> List[str]:
        """Split *text* on commas that are not inside square brackets.

        Steps carry coordinates such as ``[0.1, 0.2, 0.3, 0.4]``; a plain
        ``str.split(",")`` would cut those apart. For bracket-free text the
        result equals ``text.split(",")``.
        """
        pieces: List[str] = []
        depth = 0
        start = 0
        for idx, char in enumerate(text):
            if char == "[":
                depth += 1
            elif char == "]" and depth:
                depth -= 1
            elif char == "," and depth == 0:
                pieces.append(text[start:idx])
                start = idx + 1
        pieces.append(text[start:])
        return pieces

    @staticmethod
    def _parse_history(history: str) -> List[str]:
        """Turn a free-form history string into a list of completed steps.

        The string is split on the "β" separator when present, otherwise on
        top-level commas; empty fragments are discarded. Empty or
        whitespace-only input yields an empty list.
        """
        if not history or not history.strip():
            return []
        # NOTE(review): "β" may be a mis-encoded arrow separator coming from
        # the frontend - confirm against the caller before changing it.
        if "β" in history:
            fragments = history.split("β")
        else:
            fragments = ModelInference._split_top_level_commas(history)
        return [part.strip() for part in fragments if part.strip()]

    @staticmethod
    def _strip_numbering(line: str) -> str:
        """Remove a leading ``"1. "`` / ``"1) "`` style prefix from *line*.

        Only a run of digits immediately followed by ``.`` or ``)`` at the
        very start is removed, so dots in coordinates and parentheses in the
        description are left alone.
        """
        remainder = line.lstrip("0123456789")
        if len(remainder) != len(line) and remainder[:1] in (".", ")"):
            return remainder[1:].strip()
        return line

    @staticmethod
    def _parse_step_lines(text: str) -> List[str]:
        """Extract step lines from raw model output.

        Blank lines are skipped, list numbering is removed, and only lines
        beginning with a known action keyword (case-insensitive) are kept.
        """
        steps: List[str] = []
        for raw_line in text.split("\n"):
            candidate = raw_line.strip()
            if not candidate:
                continue
            candidate = ModelInference._strip_numbering(candidate)
            if candidate.upper().startswith(ModelInference._ACTION_PREFIXES):
                steps.append(candidate)
        return steps

    @staticmethod
    def _number_steps(steps: List[str]) -> str:
        """Render *steps* as a 1-based numbered list, one step per line."""
        return "\n".join(
            f"{position}. {step}" for position, step in enumerate(steps, start=1)
        )

    def _single_step_prompt(self, goal: str, completed_steps: List[str]) -> str:
        """Build the prompt asking for exactly one (first or next) step."""
        if completed_steps:
            lines = [
                "Analyze this UI and generate the next test step.",
                f"Task: {goal}",
                "Completed:",
                self._number_steps(completed_steps),
            ]
            closing = "Next step only:"
        else:
            lines = [
                "Analyze this UI and generate the first test step.",
                f"Task: {goal}",
            ]
            closing = "First step only:"
        lines.extend(self._STEP_FORMAT_LINES)
        lines.append(closing)
        return "\n".join(lines)

    def _sequence_prompt(
        self, goal: str, completed_steps: List[str], max_steps: int
    ) -> str:
        """Build the prompt asking for all (remaining) steps at once."""
        if completed_steps:
            lines = [
                "Analyze this UI and generate ALL remaining test steps to "
                "complete the task.",
                f"Task: {goal}",
                "Already completed steps:",
                self._number_steps(completed_steps),
                "Generate the REMAINING steps needed to complete the task.",
            ]
        else:
            lines = [
                "Analyze this UI and generate ALL test steps to complete the task.",
                f"Task: {goal}",
                "Generate a complete sequence of steps to accomplish this task.",
            ]
        lines.extend(self._SEQUENCE_RULE_LINES)
        lines.append(f"- Maximum {max_steps} steps")
        lines.append("Steps:")
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def process_single_image(self, image: "Image.Image") -> "Image.Image":
        """Return *image* in RGB mode, converting it only when necessary."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def predict_next_step_with_history(
        self,
        image: "Image.Image",
        goal: str,
        completed_steps: Optional[List[str]] = None,
    ) -> str:
        """Predict the single next test step for *goal*.

        Args:
            image: Screenshot of the UI under test.
            goal: Description of the task to accomplish.
            completed_steps: Steps already performed; ``None`` means none.

        Returns:
            The model's response text with surrounding whitespace removed.

        Raises:
            Exception: Any error from image conversion or the model call is
                logged and re-raised unchanged.
        """
        try:
            if completed_steps is None:
                completed_steps = []
            image = self.process_single_image(image)
            prompt = self._single_step_prompt(goal, completed_steps)
            logger.info(
                f"Generating prediction with {len(completed_steps)} history steps"
            )
            response = self.model.generate_content([prompt, image])
            prediction = response.text.strip()
            logger.info(f"Generated prediction: {prediction}")
            return prediction
        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}")
            raise

    def generate_step_sequence(
        self,
        image: "Image.Image",
        task_description: str,
        action_history: str = "",
        max_steps: int = 10,
    ) -> List[str]:
        """Generate the step sequence for *task_description*.

        Thin wrapper that forwards to ``generate_recursive_workflow``.
        """
        logger.info("Using recursive history-aware workflow generation")
        return self.generate_recursive_workflow(
            image=image,
            goal=task_description,
            initial_history=action_history,
            max_steps=max_steps,
        )

    def generate_recursive_workflow(
        self,
        image: "Image.Image",
        goal: str,
        initial_history: str = "",
        max_steps: int = 10,
    ) -> List[str]:
        """Generate all remaining workflow steps in a single model call.

        Args:
            image: Screenshot of the UI under test.
            goal: Description of the task to accomplish.
            initial_history: Already completed steps as one string, separated
                by "β" or by commas (commas inside ``[...]`` do not split).
            max_steps: Upper bound on the number of newly generated steps;
                stated in the prompt and enforced on the parsed result.

        Returns:
            The parsed history followed by the newly generated steps.

        Raises:
            Exception: Any error from the model call or response handling is
                logged and re-raised unchanged.
        """
        completed_steps = self._parse_history(initial_history)
        logger.info(f"Generating all workflow steps at once for goal: {goal}")
        logger.info(f"Initial history: {completed_steps}")

        image = self.process_single_image(image)
        prompt = self._sequence_prompt(goal, completed_steps, max_steps)

        try:
            logger.info("Generating all steps in single API call...")
            response = self.model.generate_content([prompt, image])
            all_steps_text = response.text.strip()

            # The model is only *asked* to respect the limit, so cap the
            # parsed result as well.
            new_steps = self._parse_step_lines(all_steps_text)[: max(max_steps, 0)]

            logger.info(f"Generated {len(new_steps)} steps in one call")
            for offset, step in enumerate(new_steps):
                logger.info(f"Step {len(completed_steps) + offset + 1}: {step}")
            return completed_steps + new_steps
        except Exception as e:
            logger.error(f"Error generating all steps: {str(e)}")
            raise
# Build the shared ModelInference instance once at import time; the endpoint
# functions below read it as a module-level global.
logger.info("Initializing model inference...")
model_inference = ModelInference()
logger.info("Model inference ready!")
async def root():
    """Return a short status payload for the service.

    The payload carries a fixed status and message, the device string stored
    on the shared ``model_inference`` object, and a constant ``model_loaded``
    flag.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # source as received - confirm how it is registered on `app`.
    status_payload = {
        "status": "running",
        "message": "Sequential Test Step Generator API",
        "device": str(model_inference.device),
        "model_loaded": True,
    }
    return status_payload
def _split_action_history(action_history: str) -> List[str]:
    """Turn the submitted history string into a list of completed steps.

    The string is split on the "β" separator when present, otherwise on
    commas that are not inside square brackets (so step coordinates such as
    ``[0.1, 0.2, 0.3, 0.4]`` stay intact). Empty fragments are dropped and
    empty or whitespace-only input yields an empty list.
    """
    if not action_history or not action_history.strip():
        return []
    if "β" in action_history:
        fragments = action_history.split("β")
    else:
        fragments = []
        depth = 0
        start = 0
        for idx, char in enumerate(action_history):
            if char == "[":
                depth += 1
            elif char == "]" and depth:
                depth -= 1
            elif char == "," and depth == 0:
                fragments.append(action_history[start:idx])
                start = idx + 1
        fragments.append(action_history[start:])
    return [part.strip() for part in fragments if part.strip()]


async def predict(
    image: UploadFile = File(..., description="UI screenshot image"),
    action_history: str = Form(default="", description="Previous action history"),
    task_description: str = Form(..., description="Task description"),
    generate_sequence: bool = Form(
        default=True, description="Generate full sequence or single action"
    ),
):
    """Generate test steps from a UI screenshot, action history and task.

    Args:
        image: Uploaded screenshot of the UI.
        action_history: Previously performed steps as a single string.
        task_description: The task the generated steps should accomplish.
        generate_sequence: When true, generate the whole sequence (capped at
            10 new steps); otherwise predict only the next step.

    Returns:
        On success a dict with ``success``, ``steps``, ``image_size`` and
        ``num_steps``. On any failure the error is logged and a dict with
        ``success=False``, a constant ``error`` value and empty ``steps`` is
        returned instead of raising.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # source as received - confirm how it is registered on `app`.
    try:
        # NOTE(review): fixed half-second pause before reading the upload;
        # its purpose is not evident from this file - confirm it is needed.
        await asyncio.sleep(0.5)
        image_data = await image.read()
        pil_image = Image.open(io.BytesIO(image_data))
        logger.info(f"Received image: {pil_image.size}, mode: {pil_image.mode}")
        logger.info(f"Task description: {task_description}")
        logger.info(
            f"Action history: {action_history[:100]}..."
            if action_history
            else "No history"
        )

        # The model calls are synchronous and wait on a remote API; run them
        # in a worker thread so the event loop keeps serving other requests.
        if generate_sequence:
            predicted_steps = await asyncio.to_thread(
                model_inference.generate_step_sequence,
                image=pil_image,
                task_description=task_description,
                action_history=action_history,
                max_steps=10,
            )
        else:
            completed_steps = _split_action_history(action_history)
            predicted_action = await asyncio.to_thread(
                model_inference.predict_next_step_with_history,
                image=pil_image,
                goal=task_description,
                completed_steps=completed_steps,
            )
            predicted_steps = [predicted_action]

        return {
            "success": True,
            "steps": predicted_steps,
            "image_size": pil_image.size,
            "num_steps": len(predicted_steps),
        }
    except Exception as e:
        # Top-level request boundary: log with traceback and return a
        # generic failure payload so internals are not leaked to the client.
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        return {"success": False, "error": "ERROR", "steps": []}
async def health():
    """Return a health payload that reflects the shared model object.

    Reports a fixed ``"healthy"`` status, the device string stored on
    ``model_inference``, and whether its ``model`` attribute is set.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # source as received - confirm how it is registered on `app`.
    has_model = model_inference.model is not None
    health_payload = {
        "status": "healthy",
        "device": str(model_inference.device),
        "model_loaded": has_model,
    }
    return health_payload
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn on all interfaces, port 8000.
    # uvicorn is imported here so importing this module does not require it.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)