Spaces:
Sleeping
Sleeping
| import asyncio | |
| import httpx | |
| from enum import Enum | |
| from src.utils import image_path_to_uri | |
| from dotenv import load_dotenv | |
| import os | |
| from pydantic import BaseModel, Field | |
| from typing import List | |
| load_dotenv() | |
class Environment(Enum):
    """Deployment targets of the Hopter serving API."""

    STAGING = "staging"
    PRODUCTION = "production"

    def base_url(self) -> str:
        """Return the root URL of the serving host for this environment."""
        # One entry per member; .get() keeps the implicit None for any member
        # that has no URL registered, exactly like a match without a default.
        hosts = {
            Environment.STAGING: "https://serving.hopter.staging.picc.co",
            Environment.PRODUCTION: "https://serving.hopter.picc.co",
        }
        return hosts.get(self)
class RamGroundedSamInput(BaseModel):
    # Input fields for a mask-generation request; both are required.
    # (A `#` comment is used instead of a docstring so the generated JSON
    # schema of the model is left unchanged.)
    text_prompt: str = Field(
        default=...,
        description="The text prompt for the mask generation.",
    )
    image_b64: str = Field(
        default=...,
        description="The image in base64 format.",
    )
class RamGroundedSamResult(BaseModel):
    # One detected instance: its mask, label, score and bounding box.
    # All four fields are required.
    mask_b64: str = Field(
        default=...,
        description="The mask image in base64 format.",
    )
    class_label: str = Field(
        default=...,
        description="The class label of the mask.",
    )
    confidence: float = Field(
        default=...,
        description="The confidence score of the mask.",
    )
    bbox: List[float] = Field(
        default=...,
        description="The bounding box of the mask in the format [x1, y1, x2, y2].",
    )
class MagicReplaceInput(BaseModel):
    # Input fields for a magic-replace request: source image, the mask
    # selecting the region, and the edit prompt. All three are required.
    image: str = Field(
        default=...,
        description="The image in base64 format.",
    )
    mask: str = Field(
        default=...,
        description="The mask in base64 format.",
    )
    prompt: str = Field(
        default=...,
        description="The prompt for the magic replace.",
    )
class MagicReplaceResult(BaseModel):
    # Output of a magic-replace request: the single, required edited image.
    base64_image: str = Field(
        default=...,
        description="The edited image in base64 format.",
    )
class SuperResolutionInput(BaseModel):
    # Input fields for a super-resolution request. Only the image is
    # required; scale defaults to 4 and face enhancement defaults to off.
    image_b64: str = Field(
        default=...,
        description="The image in base64 format.",
    )
    scale: int = Field(
        default=4,
        description="The scale of the image to upscale to.",
    )
    use_face_enhancement: bool = Field(
        default=False,
        description="Whether to use face enhancement.",
    )
class SuperResolutionResult(BaseModel):
    # Output of a super-resolution request: the single, required upscaled image.
    scaled_image: str = Field(
        default=...,
        description="The super-resolved image in base64 format.",
    )
class Hopter:
    """Synchronous client for the Hopter prediction services.

    Every public call POSTs ``{"input": <model fields>}`` to
    ``<base_url>/api/v1/services/<service>/predictions`` with a bearer token
    and turns the ``output`` entry of the JSON response into a result model.

    Failures are reported on stdout and the call returns ``None``; no
    exception from the request, the response parsing or the result
    validation propagates to the caller.

    The instance owns an ``httpx.Client``. Call :meth:`close` when done, or
    use the instance as a context manager.
    """

    def __init__(self, api_key: str, environment: Environment = Environment.PRODUCTION):
        """Store the credentials and open the HTTP client.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            environment: Deployment whose base URL is used for all requests.
        """
        self.api_key = api_key
        # Environment.base_url is a plain method, not a property. It has to be
        # called here; otherwise the bound-method object itself would be
        # interpolated into every request URL.
        self.base_url = environment.base_url()
        self.client = httpx.Client()

    def __enter__(self) -> "Hopter":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client and its connections."""
        self.client.close()

    def _predict(self, service: str, payload, result_type, done_message: str, extract=None):
        """POST ``payload`` to ``service`` and build ``result_type`` from the output.

        Args:
            service: Service name placed in the request path.
            payload: Request model; sent as ``payload.model_dump()``.
            result_type: Model class instantiated from the extracted mapping.
            done_message: Text printed once the output has been extracted.
            extract: Optional callable that picks the mapping to build the
                result from out of the response's ``output`` value. When
                omitted, ``output`` itself is used.

        Returns:
            A ``result_type`` instance, or ``None`` if any step failed.
        """
        try:
            response = self.client.post(
                f"{self.base_url}/api/v1/services/{service}/predictions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
                json={"input": payload.model_dump()},
                # No timeout: the call waits for as long as the service takes.
                timeout=None,
            )
            response.raise_for_status()  # Raise an error for bad responses
            output = response.json().get("output")
            instance = output if extract is None else extract(output)
            print(done_message)
            return result_type(**instance)
        except httpx.HTTPStatusError as exc:
            print(
                f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}"
            )
        except Exception as exc:
            # Deliberate best-effort: transport errors, malformed responses and
            # result-validation errors are reported, not raised.
            print(f"An unexpected error occurred: {exc}")
        return None

    def generate_mask(self, input: RamGroundedSamInput) -> RamGroundedSamResult | None:
        """Request a mask for ``input.text_prompt`` on the given image.

        Returns:
            The first instance of the service output, or ``None`` on failure.
        """
        print(f"Generating mask with input: {input.text_prompt}")
        return self._predict(
            "ram-grounded-sam-api",
            input,
            RamGroundedSamResult,
            "Generated mask.",
            # The output carries a list of instances; only the first is used.
            extract=lambda output: output.get("instances")[0],
        )

    def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult | None:
        """Request an edit of the masked image region described by ``input.prompt``.

        Returns:
            The edited-image result, or ``None`` on failure.
        """
        print(f"Magic replacing with input: {input.prompt}")
        return self._predict(
            "sdxl-magic-replace",
            input,
            MagicReplaceResult,
            "Magic replaced.",
        )

    def super_resolution(self, input: SuperResolutionInput) -> SuperResolutionResult | None:
        """Request an upscaled version of the given image.

        Returns:
            The upscaled-image result, or ``None`` on failure.
        """
        return self._predict(
            "super-resolution-esrgan",
            input,
            SuperResolutionResult,
            "Super-resolutin done",
        )
async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
    """Generate a mask for the fixed prompt "pole" and return it.

    Args:
        hopter: Client used to issue the request.
        image_url: Image passed as the request's ``image_b64`` field.

    Returns:
        The ``mask_b64`` field of the generated mask.

    Raises:
        RuntimeError: If the client returned no result.
    """
    request = RamGroundedSamInput(text_prompt="pole", image_b64=image_url)
    mask = hopter.generate_mask(request)
    # The client reports failures by printing and returning None; fail here
    # with a clear message instead of an AttributeError on None.
    if mask is None:
        raise RuntimeError("Mask generation failed: the client returned no result.")
    return mask.mask_b64
async def test_magic_replace(
    hopter: Hopter, image_url: str, mask: str, prompt: str
) -> str:
    """Run a magic-replace request and return the edited image.

    Args:
        hopter: Client used to issue the request.
        image_url: Image passed as the request's ``image`` field.
        mask: Mask passed as the request's ``mask`` field.
        prompt: Edit instruction passed as the request's ``prompt`` field.

    Returns:
        The ``base64_image`` field of the result.

    Raises:
        RuntimeError: If the client returned no result.
    """
    request = MagicReplaceInput(image=image_url, mask=mask, prompt=prompt)
    result = hopter.magic_replace(request)
    # The client reports failures by printing and returning None; fail here
    # with a clear message instead of an AttributeError on None.
    if result is None:
        raise RuntimeError("Magic replace failed: the client returned no result.")
    return result.base64_image
async def main():
    """Demo run against staging: mask the pole in the sample image, then remove it.

    Raises:
        RuntimeError: If the HOPTER_API_KEY environment variable is missing
            or empty.
    """
    api_key = os.getenv("HOPTER_API_KEY")
    # Without this check a missing key would be sent as "Bearer None".
    if not api_key:
        raise RuntimeError("HOPTER_API_KEY environment variable is not set.")

    hopter = Hopter(api_key=api_key, environment=Environment.STAGING)
    try:
        image_file_path = "./assets/lakeview.jpg"
        # NOTE(review): image_path_to_uri presumably returns a data URI for
        # the file; confirm against src.utils.
        image_url = image_path_to_uri(image_file_path)
        mask = await test_generate_mask(hopter, image_url)
        magic_replace_result = await test_magic_replace(
            hopter, image_url, mask, "remove the pole"
        )
        print(magic_replace_result)
    finally:
        # Release the HTTP client even when one of the steps fails.
        hopter.client.close()


if __name__ == "__main__":
    asyncio.run(main())