# FastAPI service: poster prompt enhancement (Groq LLM) and poster image
# generation/updating (Google Gemini image model).
import base64
import os
from io import BytesIO
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Import the required packages for the language model and prompts
from langchain_groq import ChatGroq
from langchain_core.prompts import PromptTemplate
# Import packages for image processing and Google GenAI client
from google import genai
from google.genai import types
from PIL import Image
# --- Global Initialization ---
# Get API keys from environment variables
# NOTE(review): LANGCHAIN_API_KEY is later passed to ChatGroq as its api_key —
# confirm this env var actually holds the Groq key despite its name.
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# Fail fast at import time: every endpoint depends on both keys, so a missing
# key should stop the server from starting rather than fail per-request.
if not LANGCHAIN_API_KEY or not GOOGLE_API_KEY:
    raise EnvironmentError("API keys for LANGCHAIN_API_KEY and GOOGLE_API_KEY must be set as environment variables.")
def get_llm(api_key: str):
    """Build the ChatGroq chat-model instance used for prompt enhancement.

    Args:
        api_key: Groq API key used to authenticate the client.

    Returns:
        A configured ``ChatGroq`` instance (Llama-4 Scout, temperature 1,
        capped at 1024 output tokens).
    """
    # No intermediate variable needed — the constructor call is the whole body.
    return ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=1,
        max_tokens=1024,
        api_key=api_key,
    )
# Initialize the LLM instance
# NOTE(review): the Groq model is keyed with LANGCHAIN_API_KEY — verify this
# env var holds a Groq key (its name suggests a LangChain key instead).
llm = get_llm(LANGCHAIN_API_KEY)
# Define the prompt template for enhancing raw prompts.
# Single input variable {Raw_Prompt}; filled in by the /enhance-prompt endpoint.
prompt_template = PromptTemplate.from_template('''
You are a Prompt Enhancement AI Assistant. Your task is to take the user's raw poster image prompt and convert it into a detailed, professional prompt optimized for generating high-quality AI poster.
Enhance the prompt by including relevant details such as:
- Camera specifications (e.g., lens type, aperture, focal length)
- Lighting setup (e.g., natural light, studio lighting, soft shadows)
- Camera angle (e.g., top-down, macro, isometric, side view)
- Background style (e.g., plain white, minimalistic, outdoor, studio backdrop)
- Scene composition (e.g., centered product, depth of field, reflections)
- Call-to-Action: include a clear, concise CTA in the poster text (e.g., "Shop Now", "Learn More")
Focus only on *Poster photography* — do not include humans or models.
Raw Prompt:
{Raw_Prompt}
Enhanced Prompt:
''')
# Initialize the Google GenAI client.
# Shared by the generate and update endpoints below.
client1 = genai.Client(api_key=GOOGLE_API_KEY)
# --- Request Models ---
class EnhancePromptRequest(BaseModel):
    """Request body for POST /enhance-prompt."""

    # The user's raw poster prompt to be rewritten by the LLM.
    raw_prompt: str
class GenerateImageRequest(BaseModel):
    """Request body for POST /generate-poster.

    At least one prompt must be supplied; when both are present,
    ``enhanced_prompt`` takes priority (enforced in the endpoint).
    """

    # Optional[...] instead of the original bare `str = None`: an annotation
    # that does not admit None is rejected by Pydantic v2 validation, and
    # clients must be able to omit either field.
    raw_prompt: Optional[str] = None
    enhanced_prompt: Optional[str] = None
class UpdateImageRequest(BaseModel):
    """Request body for POST /update-poster."""

    # Natural-language edit instruction forwarded to the image model.
    text_instruction: str
    # Image encoded in base64
    image_base64: str
# --- FastAPI Initialization ---
# Single application instance; all route decorators below register on it.
app = FastAPI(title="Image Generation & Update API")
# 1. Root endpoint
@app.get("/")
async def root():
    """Simple liveness/welcome endpoint."""
    greeting = {"message": "Welcome to the Image Generation API!"}
    return greeting
# 2. Enhance Prompt endpoint
@app.post("/enhance-prompt")
async def enhance_prompt(request: EnhancePromptRequest):
    """Rewrite a raw poster prompt into a detailed, generation-ready prompt.

    Returns both the original and the enhanced prompt; any failure in the
    template/LLM pipeline surfaces as a 500 with the error message.
    """
    try:
        # Fill the template, run it through the LLM, and return its text.
        filled_prompt = prompt_template.invoke({"Raw_Prompt": request.raw_prompt})
        llm_response = llm.invoke(filled_prompt)
        return {
            "raw_prompt": request.raw_prompt,
            "enhanced_prompt": llm_response.content,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Helper function to generate image content using GenAI client.
def generate_image_from_prompt(image_prompt: str):
    """Ask the Gemini image model for text + image output for *image_prompt*.

    Returns:
        dict with keys ``"text"`` (model commentary, or None) and
        ``"image_base64"`` (base64-encoded PNG, or None). If the model emits
        several image parts, the last one wins.
    """
    response = client1.models.generate_content(
        model="gemini-2.0-flash-exp-image-generation",
        contents=image_prompt,
        config=types.GenerateContentConfig(
            response_modalities=['Text', 'Image']
        ),
    )
    output = {"text": None, "image_base64": None}
    # Walk the parts of the first candidate, collecting text and image data.
    for part in response.candidates[0].content.parts:
        if part.text is not None:
            output["text"] = part.text
        elif part.inline_data is not None:
            # Round-trip through PIL so the returned payload is always PNG,
            # whatever format the model produced.
            picture = Image.open(BytesIO(part.inline_data.data))
            png_buffer = BytesIO()
            picture.save(png_buffer, format="PNG")
            output["image_base64"] = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return output
# 3. Generate Image endpoint
@app.post("/generate-poster")
async def generate_image(request: GenerateImageRequest):
    """Generate a poster image from a prompt.

    ``enhanced_prompt`` takes priority over ``raw_prompt``.

    Raises:
        HTTPException 400: neither prompt was provided.
        HTTPException 500: the image-generation call failed.
    """
    # Select and validate the prompt BEFORE entering the try block. The
    # original raised the 400 inside it, where the blanket `except Exception`
    # caught the HTTPException and re-wrapped it as a 500.
    image_prompt = request.enhanced_prompt or request.raw_prompt
    if not image_prompt:
        raise HTTPException(status_code=400, detail="Either raw_prompt or enhanced_prompt must be provided.")
    try:
        return generate_image_from_prompt(image_prompt)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# 4. Update Image endpoint
@app.post("/update-poster")
async def update_image(request: UpdateImageRequest):
    """Apply a text edit instruction to an existing poster image.

    Decodes the base64 image, sends it with the instruction to the Gemini
    image model, and returns ``{"text": ..., "updated_image_base64": ...}``.

    Raises:
        HTTPException 400: malformed base64 or bytes PIL cannot decode
            (client error — previously these surfaced as a misleading 500).
        HTTPException 500: the upstream model call failed.
    """
    # Decode client input in its own try block so bad payloads map to 400.
    try:
        image_bytes = base64.b64decode(request.image_base64)
        image = Image.open(BytesIO(image_bytes))
    except (ValueError, OSError) as e:
        # binascii.Error (bad base64) subclasses ValueError; PIL raises
        # UnidentifiedImageError (an OSError) for bytes that are not an image.
        raise HTTPException(status_code=400, detail=f"Invalid image payload: {e}")
    try:
        # The contents for the update API include the text instruction and the current image.
        response = client1.models.generate_content(
            model="gemini-2.0-flash-exp-image-generation",
            contents=[request.text_instruction, image],
            config=types.GenerateContentConfig(
                response_modalities=['Text', 'Image']
            )
        )
        updated_result = {"text": None, "updated_image_base64": None}
        for part in response.candidates[0].content.parts:
            if part.text is not None:
                updated_result["text"] = part.text
            elif part.inline_data is not None:
                # Re-encode to PNG so the response format is predictable.
                updated_image = Image.open(BytesIO(part.inline_data.data))
                buffered = BytesIO()
                updated_image.save(buffered, format="PNG")
                updated_result["updated_image_base64"] = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return updated_result
    except HTTPException:
        # Never re-wrap an already-built HTTP error as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# --- Run the app ---
if __name__ == "__main__":
    # Local development entry point; binds on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)