| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel, Field |
| from typing import Optional, Dict, Any |
| import torch |
| from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering |
| from PIL import Image |
| import io |
| import base64 |
|
|
# FastAPI application instance; interactive API docs are served at /docs.
app = FastAPI(
    title="OmniParser API",
    description="API for parsing GUI elements from images",
    version="1.0.0"
)
|
|
| |
class OmniParser:
    """Wraps the OmniParser icon-caption model (Florence-2 based) to answer
    free-text questions about GUI screenshots.

    The processor comes from the Florence-2 base repo; the fine-tuned
    caption weights live in a subfolder of the OmniParser repo.
    """

    PROCESSOR_REPO = "microsoft/Florence-2-base"
    MODEL_REPO = "microsoft/OmniParser"
    MODEL_SUBFOLDER = "icon_caption_florence"
    CACHE_DIR = "/code/.cache"

    def __init__(self):
        # Prefer the GPU when one is visible; the model is moved there below.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.processor = AutoProcessor.from_pretrained(
            self.PROCESSOR_REPO,
            trust_remote_code=True,
            cache_dir=self.CACHE_DIR,
        )
        # BUG FIX: a Hugging Face repo id has exactly two segments
        # ("owner/name"); "microsoft/OmniParser/icon_caption_florence" is
        # rejected by from_pretrained. The weights live in the
        # "icon_caption_florence" subfolder of "microsoft/OmniParser".
        self.model = AutoModelForVisualQuestionAnswering.from_pretrained(
            self.MODEL_REPO,
            subfolder=self.MODEL_SUBFOLDER,
            trust_remote_code=True,
            cache_dir=self.CACHE_DIR,
        ).to(self.device)
        # Inference only: disable dropout / train-mode behavior.
        self.model.eval()

    @torch.inference_mode()
    def process_image(
        self,
        image: Image.Image,
        question: str = "What elements do you see in this GUI?",
    ) -> Dict[str, Any]:
        """Run the VQA model over *image* and return its textual answer.

        Args:
            image: PIL image of the GUI to parse.
            question: Prompt passed to the model alongside the image.

        Returns:
            Dict with "parsed_elements" (model answer text) and
            "box_coordinates" (always empty — box extraction is not
            implemented in this block).
        """
        inputs = self.processor(
            images=image, text=question, return_tensors="pt"
        ).to(self.device)

        # BUG FIX: Florence-2 is a generative seq2seq model; a single
        # forward pass plus per-position argmax does not yield a real
        # answer. Use autoregressive generation when the model supports
        # it, keeping the original argmax decode as a fallback.
        if hasattr(self.model, "generate"):
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=256,
            )
            predicted_answer = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]
        else:
            outputs = self.model(**inputs)
            predicted_answer = self.processor.decode(
                outputs.logits.argmax(-1)[0],
                skip_special_tokens=True,
            )

        return {
            "parsed_elements": predicted_answer,
            "box_coordinates": {}
        }
|
|
| |
| model = OmniParser() |
|
|
| |
class ParseRequest(BaseModel):
    """Request body for POST /parse: a base64 image plus an optional question."""

    image_data: str = Field(..., description="Base64 encoded image data")
    question: Optional[str] = Field(
        default="What elements do you see in this GUI?",
        description="Question to ask about the GUI"
    )
|
|
class ParseResponse(BaseModel):
    """Response body for POST /parse."""

    # Free-text description of the detected GUI elements.
    parsed_elements: str
    # Bounding-box mapping; OmniParser.process_image currently returns {}.
    box_coordinates: dict
    # Base64-encoded PNG echo of the processed input image.
    output_image: Optional[str]
|
|
def load_and_preprocess_image(image_data: bytes) -> Image.Image:
    """Decode raw image bytes into a PIL image.

    Args:
        image_data: Encoded image bytes (PNG, JPEG, ...).

    Returns:
        The decoded PIL image.  (Annotation fixed: the original declared
        Optional[Image.Image] but None was never returned.)

    Raises:
        ValueError: If the bytes cannot be decoded as an image.
    """
    try:
        image = Image.open(io.BytesIO(image_data))
        # PIL decodes lazily; force decoding now so corrupt payloads fail
        # here instead of deep inside the model pipeline.
        image.load()
        return image
    except Exception as e:
        # Chain the original error for debuggability.
        raise ValueError(f"Failed to load image: {str(e)}") from e
|
|
def encode_output_image(image: Image.Image) -> str:
    """Serialize a PIL image as a base64-encoded PNG string."""
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    raw_png = png_buffer.getvalue()
    return base64.b64encode(raw_png).decode()
|
|
@app.get("/")
async def root():
    """Health/landing endpoint pointing clients at the interactive docs."""
    payload = {
        "message": "OmniParser API is running",
        "docs_url": "/docs",
    }
    return payload
|
|
@app.post("/parse", response_model=ParseResponse)
async def parse_image(request: ParseRequest):
    """Parse GUI elements from a base64-encoded screenshot.

    Decodes the image, runs the OmniParser model, and returns the parsed
    elements plus a re-encoded copy of the input image.

    Raises:
        HTTPException 400: malformed base64 or undecodable image (client error).
        HTTPException 500: model/inference failure (server error).
    """
    # BUG FIX: the original returned 500 for malformed client input.
    # Decode first and map those failures to 400.  binascii.Error (raised
    # by b64decode on bad padding) is a ValueError subclass, as is the
    # error raised by load_and_preprocess_image, so one handler covers both.
    try:
        image_bytes = base64.b64decode(request.image_data)
        image = load_and_preprocess_image(image_bytes)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        result = model.process_image(
            image=image,
            question=request.question
        )
        return ParseResponse(
            parsed_elements=result["parsed_elements"],
            box_coordinates=result["box_coordinates"],
            output_image=encode_output_image(image)
        )
    except Exception as e:
        # Anything past input validation is a genuine server error.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
if __name__ == "__main__":
    # Development entry point: run an in-process server when executed
    # directly (production deployments typically invoke `uvicorn main:app`).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)