"""poster/main.py — FastAPI service for AI poster generation.

Endpoints: prompt enhancement via a Groq-hosted LLM, poster generation and
poster editing via Google's GenAI image model.
"""
import base64
import os
from io import BytesIO
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Language model and prompt packages
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq

# Image processing and Google GenAI client
from google import genai
from google.genai import types
from PIL import Image
# --- Global Initialization ---
# Get API keys from environment variables
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not LANGCHAIN_API_KEY or not GOOGLE_API_KEY:
raise EnvironmentError("API keys for LANGCHAIN_API_KEY and GOOGLE_API_KEY must be set as environment variables.")
def get_llm(api_key: str):
"""
Returns the ChatGroq LLM instance.
"""
llm = ChatGroq(
model="meta-llama/llama-4-scout-17b-16e-instruct",
temperature=1,
max_tokens=1024,
api_key=api_key
)
return llm
# Initialize the LLM instance
llm = get_llm(LANGCHAIN_API_KEY)
# Define the prompt template for enhancing raw prompts.
prompt_template = PromptTemplate.from_template('''
You are a Prompt Enhancement AI Assistant. Your task is to take the user's raw poster image prompt and convert it into a detailed, professional prompt optimized for generating high-quality AI poster.
Enhance the prompt by including relevant details such as:
- Camera specifications (e.g., lens type, aperture, focal length)
- Lighting setup (e.g., natural light, studio lighting, soft shadows)
- Camera angle (e.g., top-down, macro, isometric, side view)
- Background style (e.g., plain white, minimalistic, outdoor, studio backdrop)
- Scene composition (e.g., centered product, depth of field, reflections)
- Call-to-Action: include a clear, concise CTA in the poster text (e.g., "Shop Now", "Learn More")
Focus only on *Poster photography* — do not include humans or models.
Raw Prompt:
{Raw_Prompt}
Enhanced Prompt:
''')
# Initialize the Google GenAI client.
client1 = genai.Client(api_key=GOOGLE_API_KEY)
# --- Request Models ---
class EnhancePromptRequest(BaseModel):
raw_prompt: str
class GenerateImageRequest(BaseModel):
# If both are provided, enhanced_prompt takes priority.
raw_prompt: str = None
enhanced_prompt: str = None
class UpdateImageRequest(BaseModel):
text_instruction: str
# Image encoded in base64
image_base64: str
# --- FastAPI Initialization ---
app = FastAPI(title="Image Generation & Update API")
# 1. Root endpoint
@app.get("/")
async def root():
return {"message": "Welcome to the Image Generation API!"}
# 2. Enhance Prompt endpoint
@app.post("/enhance-prompt")
async def enhance_prompt(request: EnhancePromptRequest):
try:
# Prepare the prompt using the template.
formatted_prompt = prompt_template.invoke({"Raw_Prompt": request.raw_prompt})
# Call the LLM to enhance the prompt.
response = llm.invoke(formatted_prompt)
# Assume the enhanced prompt is in the response.content.
enhanced_prompt = response.content
return {"raw_prompt": request.raw_prompt, "enhanced_prompt": enhanced_prompt}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Helper function to generate image content using GenAI client.
def generate_image_from_prompt(image_prompt: str):
response = client1.models.generate_content(
model="gemini-2.0-flash-exp-image-generation",
contents=image_prompt,
config=types.GenerateContentConfig(
response_modalities=['Text', 'Image']
)
)
# Prepare a dict to hold results.
result = {"text": None, "image_base64": None}
# Process the response parts.
for part in response.candidates[0].content.parts:
if part.text is not None:
result["text"] = part.text
elif part.inline_data is not None:
# Open the image from bytes.
image = Image.open(BytesIO(part.inline_data.data))
# Convert image to bytes.
buffered = BytesIO()
image.save(buffered, format="PNG")
# Encode to base64.
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
result["image_base64"] = img_str
return result
# 3. Generate Image endpoint
@app.post("/generate-poster")
async def generate_image(request: GenerateImageRequest):
try:
# Decide which prompt to use.
if request.enhanced_prompt:
image_prompt = request.enhanced_prompt
elif request.raw_prompt:
image_prompt = request.raw_prompt
else:
raise HTTPException(status_code=400, detail="Either raw_prompt or enhanced_prompt must be provided.")
result = generate_image_from_prompt(image_prompt)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# 4. Update Image endpoint
@app.post("/update-poster")
async def update_image(request: UpdateImageRequest):
try:
# Decode the incoming base64 image.
image_bytes = base64.b64decode(request.image_base64)
image = Image.open(BytesIO(image_bytes))
# The contents for the update API include the text instruction and the current image.
response = client1.models.generate_content(
model="gemini-2.0-flash-exp-image-generation",
contents=[request.text_instruction, image],
config=types.GenerateContentConfig(
response_modalities=['Text', 'Image']
)
)
updated_result = {"text": None, "updated_image_base64": None}
for part in response.candidates[0].content.parts:
if part.text is not None:
updated_result["text"] = part.text
elif part.inline_data is not None:
updated_image = Image.open(BytesIO(part.inline_data.data))
buffered = BytesIO()
updated_image.save(buffered, format="PNG")
updated_img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
updated_result["updated_image_base64"] = updated_img_str
return updated_result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# --- Run the app ---
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)