Spaces:
Sleeping
Sleeping
| from pydantic_ai import Agent, RunContext | |
| from pydantic_ai.models.openai import OpenAIModel | |
| from dotenv import load_dotenv | |
| import os | |
| import asyncio | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import logfire | |
| from src.services.generate_mask import GenerateMaskService | |
| from src.hopter.client import ( | |
| Hopter, | |
| Environment, | |
| MagicReplaceInput, | |
| SuperResolutionInput, | |
| ) | |
| from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image | |
| import base64 | |
| import tempfile | |
# Pull variables from a local .env file into the process environment first,
# so the environment lookups below (and later in this module) can see them.
load_dotenv()

# Logfire is configured with LOGFIRE_TOKEN (None when the variable is unset),
# then its OpenAI instrumentation hook is enabled.
logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
logfire.instrument_openai()

# System prompt passed to the image-editing agent when it is constructed.
system_prompt = """
I will give you an editing instruction of the image.
if the edit instruction involved modifying parts of the image, please generate a mask for it.
if images are not provided, ask the user to provide an image.
"""
@dataclass
class ImageEditDeps:
    """Dependencies made available to the agent's tools through ``ctx.deps``.

    Declared as a dataclass so it can be built with keyword arguments, which
    is how ``main`` constructs it; a plain class with only annotations has no
    matching ``__init__`` and would raise ``TypeError``.

    Attributes:
        edit_instruction: The user's edit request; ``edit_object`` passes it
            to the mask service.
        hopter_client: Client used for the magic-replace and
            super-resolution calls.
        mask_service: Service that derives a mask instruction and the mask.
        image_url: Location of the image to edit; ``None`` when no image was
            supplied.
    """

    edit_instruction: str
    hopter_client: Hopter
    mask_service: GenerateMaskService
    image_url: Optional[str] = None
# Model handed to the agent below. The API key comes from OPENAI_API_KEY and
# is None when that variable is not set.
model = OpenAIModel("gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))
@dataclass
class EditImageResult:
    """Result returned by the image-editing tools.

    Declared as a dataclass so the tools can build it with
    ``EditImageResult(edited_image_url=...)``; without the decorator the class
    has no ``__init__`` accepting that keyword.

    Attributes:
        edited_image_url: Value returned by the upload step for the edited
            image.
    """

    edited_image_url: str
# The agent that drives image editing: it runs on `model`, is primed with
# `system_prompt`, and expects ImageEditDeps as its dependency object.
image_edit_agent = Agent(
    model,
    deps_type=ImageEditDeps,
    system_prompt=system_prompt,
)
def upload_image_from_base64(base64_image: str) -> str:
    """Decode a base64 image, write it to a temporary file, and upload it.

    Args:
        base64_image: Either a data URI such as
            ``data:image/jpeg;base64,<payload>`` or a bare base64 payload
            with no header.

    Returns:
        Whatever ``upload_image`` returns for the temporary file's path.

    Raises:
        binascii.Error: If the payload is incorrectly padded base64.
    """
    header, separator, payload = base64_image.partition(",")
    if not separator:
        # No comma means there is no header: the whole string is the payload.
        header, payload = "", base64_image

    # Reduce "data:image/jpeg;base64" (or a bare "image/jpeg") to "image/jpeg".
    # Comparing the raw header against "image/jpeg" never matched a data URI,
    # so JPEG images were always written with a ".png" suffix.
    mime_type = header.split(":", 1)[-1].split(";", 1)[0].strip().lower()
    suffix = ".jpg" if mime_type == "image/jpeg" else ".png"

    image_data = base64.b64decode(payload)

    # Write through the handle NamedTemporaryFile already holds instead of
    # reopening the same path a second time.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(image_data)
        temp_filename = temp_file.name

    # NOTE(review): the temporary file is intentionally left on disk
    # (delete=False), as before. Whether it can be removed after upload
    # depends on what upload_image does with the path - confirm, then clean up.
    return upload_image(temp_filename)
# Agent tool: derives a mask from the edit instruction, runs Hopter
# magic-replace on the masked region, uploads the output, and returns its URL
# wrapped in EditImageResult.
#
# Registered with @image_edit_agent.tool so the agent can actually call it;
# without the decorator the function is never exposed to the model.
# The docstring is kept verbatim because pydantic_ai uses a tool's docstring
# as the description sent to the model.
#
# Raises ValueError when the dependencies carry no image URL.
@image_edit_agent.tool
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    Use this tool to edit an object in the image. for example:
    - remove the pole
    - replace the dog with a cat
    - change the background to a beach
    - remove the person in the image
    - change the hair color to red
    - change the hat to a cap
    """
    edit_instruction = ctx.deps.edit_instruction
    image_url = ctx.deps.image_url
    mask_service = ctx.deps.mask_service
    hopter_client = ctx.deps.hopter_client

    # image_url is Optional on the deps object; fail with a clear message
    # rather than passing None into the download helper.
    if not image_url:
        raise ValueError("image_url is required to edit an object in the image")

    image_uri = download_image_to_data_uri(image_url)

    # Generate mask
    mask_instruction = mask_service.get_mask_generation_instruction(
        edit_instruction, image_url
    )
    mask = mask_service.generate_mask(mask_instruction, image_uri)

    # Magic replace (local renamed from `input` to avoid shadowing the builtin)
    replace_input = MagicReplaceInput(
        image=image_uri, mask=mask, prompt=mask_instruction.target_caption
    )
    result = hopter_client.magic_replace(replace_input)

    uploaded_image = upload_image_from_base64(result.base64_image)
    return EditImageResult(edited_image_url=uploaded_image)
# Agent tool: sends the image through Hopter super-resolution at 4x scale with
# face enhancement disabled, uploads the output, and returns its URL wrapped
# in EditImageResult.
#
# Registered with @image_edit_agent.tool so the agent can actually call it;
# without the decorator the function is never exposed to the model.
# The docstring is kept verbatim because pydantic_ai uses a tool's docstring
# as the description sent to the model.
#
# Raises ValueError when the dependencies carry no image URL.
@image_edit_agent.tool
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    run super resolution, upscale, or enhance the image
    """
    image_url = ctx.deps.image_url
    hopter_client = ctx.deps.hopter_client

    # image_url is Optional on the deps object; fail with a clear message
    # rather than passing None into the download helper.
    if not image_url:
        raise ValueError("image_url is required to run super resolution")

    image_uri = download_image_to_data_uri(image_url)

    # Local renamed from `input` to avoid shadowing the builtin.
    resolution_input = SuperResolutionInput(
        image_b64=image_uri, scale=4, use_face_enhancement=False
    )
    result = hopter_client.super_resolution(resolution_input)

    uploaded_image = upload_image_from_base64(result.scaled_image)
    return EditImageResult(edited_image_url=uploaded_image)
async def main():
    """Run one sample edit against a bundled image and print the streamed output.

    Builds the Hopter client and mask service, packs them into ImageEditDeps
    together with the prompt and image URI, then streams the agent's response
    to stdout.
    """
    prompt = "remove the light post"
    image_url = image_path_to_uri("./assets/lakeview.jpg")

    # Services shared by the tools via the dependency object.
    hopter = Hopter(
        api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
    )
    deps = ImageEditDeps(
        edit_instruction=prompt,
        image_url=image_url,
        hopter_client=hopter,
        mask_service=GenerateMaskService(hopter=hopter),
    )

    # Text part first, then the image part.
    text_part = {"type": "text", "text": prompt}
    image_part = {"type": "image_url", "image_url": {"url": image_url}}
    messages = [text_part, image_part]

    async with image_edit_agent.run_stream(messages, deps=deps) as stream_result:
        async for chunk in stream_result.stream():
            print(chunk)
# Script entry point: only runs the demo when the file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    entry_coroutine = main()
    asyncio.run(entry_coroutine)