Spaces:
Running
Running
import json
import logging
import os
import tempfile
from typing import Dict, List, Optional

from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# Import project modules
from src.agents.generic_agent import generic_agent
from src.agents.image_edit_agent import ImageEditDeps, image_edit_agent
from src.hopter.client import Environment, Hopter
from src.services.generate_mask import GenerateMaskService
from src.utils import upload_image
# Populate os.environ from a local .env file (if present) before any handler
# reads configuration such as HOPTER_API_KEY.
load_dotenv()

app = FastAPI(title="Image Edit API")

# Fully permissive CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# risky combination (Starlette reflects the caller's Origin for credentialed
# requests in this setup - verify); restrict allow_origins if this API ever
# relies on cookies or HTTP auth.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
class EditRequest(BaseModel):
    """Request model taken by the edit_image / edit_image_stream handlers."""

    # Edit instruction; the handlers forward it to the agent as the text part
    # of the prompt.
    edit_instruction: str
    # Optional URL of the source image; when set, the handlers forward it to
    # the agent as an image part. Defaults to None.
    image_url: Optional[str] = None
class MessageContent(BaseModel):
    """One content part of a message.

    Appears to mirror the part dicts the edit handlers build
    ({"type": "text", "text": ...} and {"type": "image_url", "image_url": {"url": ...}}).
    NOTE(review): not referenced elsewhere in this view - confirm it is used.
    """

    # Part kind, e.g. "text" or "image_url" (not validated here).
    type: str
    # Text content; presumably set only for "text" parts. Defaults to None.
    text: Optional[str] = None
    # String mapping such as {"url": ...}; presumably set only for
    # "image_url" parts. Defaults to None.
    image_url: Optional[Dict[str, str]] = None
class Message(BaseModel):
    """A message composed of an ordered list of content parts.

    NOTE(review): not referenced elsewhere in this view - confirm it is used.
    """

    # The message's content parts, in order.
    content: List[MessageContent]
async def test(query: str):
    """Stream the generic agent's response to ``query``.

    Each item produced by the agent's stream (``debounce_by=0.01``) is
    JSON-encoded and sent as its own newline-terminated line.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on `app`.
    # NOTE(review): the body is newline-delimited JSON but is labelled
    # text/plain, whereas edit_image_stream labels the same format
    # application/x-ndjson; consider aligning the two.

    async def _ndjson_lines():
        async with generic_agent.run_stream(query) as run:
            async for item in run.stream(debounce_by=0.01):
                encoded = json.dumps(item)
                yield f"{encoded}\n"

    lines = _ndjson_lines()
    return StreamingResponse(lines, media_type="text/plain")
def _json_default(obj):
    """Fallback encoder passed as ``default=`` to ``json.dumps``.

    Objects exposing ``model_dump`` (pydantic v2 models, e.g. partial
    structured outputs an agent may stream) are converted with
    ``model_dump(mode="json")``. Anything else is rejected with the same kind
    of TypeError ``json.dumps`` raises on its own. Values that ``json.dumps``
    already handles never reach this function, so their encoding is unchanged.
    """
    model_dump = getattr(obj, "model_dump", None)
    if callable(model_dump):
        return model_dump(mode="json")
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _build_edit_inputs(request: EditRequest):
    """Build the agent dependencies and prompt for an edit request.

    Shared by ``edit_image`` and ``edit_image_stream`` so both endpoints set up
    the agent identically.

    Returns:
        A ``(deps, messages)`` tuple. ``deps`` is the ``ImageEditDeps`` for the
        agent. ``messages`` is a list of content-part dicts: the instruction
        text, followed by the image URL when one was given.
    """
    # NOTE(review): an unset HOPTER_API_KEY is passed through as None; confirm
    # Hopter reports that clearly, or validate it here.
    hopter = Hopter(
        api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
    )
    mask_service = GenerateMaskService(hopter=hopter)
    deps = ImageEditDeps(
        edit_instruction=request.edit_instruction,
        image_url=request.image_url,
        hopter_client=hopter,
        mask_service=mask_service,
    )
    # NOTE(review): these are OpenAI-style content-part dicts; confirm that
    # image_edit_agent accepts this shape as its prompt.
    messages = [{"type": "text", "text": request.edit_instruction}]
    if request.image_url:
        messages.append(
            {"type": "image_url", "image_url": {"url": request.image_url}}
        )
    return deps, messages


# NOTE(review): no route decorators are visible for edit_image /
# edit_image_stream; confirm how they are registered on `app`.
async def edit_image(request: EditRequest):
    """Edit an image based on the provided instruction.

    Returns:
        ``{"edited_image_url": ...}`` read from the agent's run result.

    Raises:
        HTTPException: status 500, with the error text as detail, if setup or
            the agent run fails.
    """
    try:
        deps, messages = _build_edit_inputs(request)
        result = await image_edit_agent.run(messages, deps=deps)
        # NOTE(review): if image_edit_agent is a pydantic-ai Agent, the run's
        # payload lives on `result.output` (`result.data` in older releases),
        # not on the result wrapper itself - confirm this attribute exists.
        return {"edited_image_url": result.edited_image_url}
    except Exception as e:
        # FastAPI does not log HTTPException responses, so record the
        # traceback here before converting the failure into a 500.
        logging.getLogger(__name__).exception("edit_image failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


async def edit_image_stream(request: EditRequest):
    """Edit an image based on the provided instruction, streaming progress.

    Streams the agent's responses back to the client as newline-delimited
    JSON, one line per streamed item. If the agent fails after streaming has
    begun, a final ``{"error": "<message>"}`` line is sent, so the stream does
    not just end without explanation.

    Raises:
        HTTPException: status 500 if building the services fails before
            streaming starts.
    """
    try:
        deps, messages = _build_edit_inputs(request)
    except Exception as e:
        logging.getLogger(__name__).exception("edit_image_stream setup failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    async def stream_generator():
        # This body runs only after the response status and headers have been
        # sent, so an HTTPException can no longer take effect: report
        # failures in-band instead.
        try:
            async with image_edit_agent.run_stream(messages, deps=deps) as result:
                async for message in result.stream():
                    yield json.dumps(message, default=_json_default) + "\n"
        except Exception as e:
            logging.getLogger(__name__).exception(
                "edit_image_stream failed while streaming"
            )
            yield json.dumps({"error": str(e)}) + "\n"

    return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
async def upload_image_file(file: UploadFile = File(...)):
    """Upload an image file and return its URL.

    The upload is saved under its original base name inside a fresh temporary
    directory and handed to ``upload_image``. The directory is removed
    afterwards, whether or not the upload succeeded.

    Returns:
        ``{"image_url": <value returned by upload_image>}``.

    Raises:
        HTTPException: status 500, with the error text as detail, if saving or
            uploading fails.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on `app`.
    try:
        # The filename is client-controlled. Keep only its base name so values
        # like "../../etc/passwd" cannot escape the temporary directory. Fall
        # back to a fixed name when nothing usable is left (e.g. an empty
        # filename, which previously resolved to the /tmp/ directory itself).
        name = os.path.basename(file.filename or "")
        if name in ("", ".", ".."):
            name = "upload"
        # A per-request directory keeps concurrent uploads with the same
        # filename apart and is cleaned up even if upload_image raises.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, name)
            with open(temp_file_path, "wb") as buffer:
                buffer.write(await file.read())
            # Upload the image to Google Cloud Storage.
            # NOTE(review): upload_image is called synchronously from this
            # async handler, so it blocks the event loop while it runs.
            image_url = upload_image(temp_file_path)
        return {"image_url": image_url}
    except Exception as e:
        # FastAPI does not log HTTPException responses; keep the traceback.
        logging.getLogger(__name__).exception("upload_image_file failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def health_check():
    """Report that the service process is up.

    Always returns ``{"status": "ok"}`` and performs no dependency checks.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on `app`.
    payload = {"status": "ok"}
    return payload
if __name__ == "__main__":
    # Direct-execution entry point: serve `app` with uvicorn on all
    # interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")