import os
import base64
from io import BytesIO

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# LLM + prompt-template tooling.
from langchain_groq import ChatGroq
from langchain_core.prompts import PromptTemplate

# Google GenAI client and image handling.
from google import genai
from google.genai import types
from PIL import Image

# --- Global Initialization ---

# Keys are read once at import time; failing fast here beats failing on the
# first request. NOTE(review): LANGCHAIN_API_KEY is actually used as the Groq
# key below — confirm the env-var name matches the deployment config.
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

if not LANGCHAIN_API_KEY or not GOOGLE_API_KEY:
    raise EnvironmentError("API keys for LANGCHAIN_API_KEY and GOOGLE_API_KEY must be set as environment variables.")


def get_llm(api_key: str):
    """
    Build and return the ChatGroq chat-model client.

    Args:
        api_key: Groq API key used to authenticate requests.

    Returns:
        A configured ``ChatGroq`` instance (Llama-4 Scout, temperature 1,
        1024 max output tokens).
    """
    return ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=1,
        max_tokens=1024,
        api_key=api_key,
    )


# Single shared LLM instance for the whole process.
llm = get_llm(LANGCHAIN_API_KEY)

# Prompt template that turns a user's raw poster idea into a detailed,
# generation-ready prompt. {Raw_Prompt} is the only template variable.
prompt_template = PromptTemplate.from_template('''
You are a Prompt Enhancement AI Assistant. Your task is to take the user's raw poster image prompt and convert it into a detailed, professional prompt optimized for generating high-quality AI poster.

Enhance the prompt by including relevant details such as:
- Camera specifications (e.g., lens type, aperture, focal length)
- Lighting setup (e.g., natural light, studio lighting, soft shadows)
- Camera angle (e.g., top-down, macro, isometric, side view)
- Background style (e.g., plain white, minimalistic, outdoor, studio backdrop)
- Scene composition (e.g., centered product, depth of field, reflections)
- Call-to-Action: include a clear, concise CTA in the poster text (e.g., "Shop Now", "Learn More")

Focus only on *Poster photography* — do not include humans or models.

Raw Prompt: {Raw_Prompt}

Enhanced Prompt:
''')
from typing import Optional

# Google GenAI client used by both the generate and update endpoints.
client1 = genai.Client(api_key=GOOGLE_API_KEY)


# --- Request Models ---

class EnhancePromptRequest(BaseModel):
    # The user's raw poster prompt to be enhanced by the LLM.
    raw_prompt: str


class GenerateImageRequest(BaseModel):
    # If both are provided, enhanced_prompt takes priority.
    # Fix: declared Optional so the None default matches the annotation
    # (a bare `str = None` is a type/default mismatch under Pydantic).
    raw_prompt: Optional[str] = None
    enhanced_prompt: Optional[str] = None


class UpdateImageRequest(BaseModel):
    # Natural-language edit instruction applied to the supplied image.
    text_instruction: str
    # Current image encoded in base64 (decoded with base64.b64decode).
    image_base64: str


# --- FastAPI Initialization ---
app = FastAPI(title="Image Generation & Update API")


# 1. Root endpoint
@app.get("/")
async def root():
    """Simple welcome/health endpoint."""
    return {"message": "Welcome to the Image Generation API!"}


# 2. Enhance Prompt endpoint
@app.post("/enhance-prompt")
async def enhance_prompt(request: EnhancePromptRequest):
    """Enhance a raw poster prompt into a detailed, generation-ready prompt.

    Returns a dict with both the original and the enhanced prompt.
    Raises HTTP 500 if the LLM call fails.
    """
    try:
        # Fill the template, then ask the LLM for the enhanced version.
        formatted_prompt = prompt_template.invoke({"Raw_Prompt": request.raw_prompt})
        response = llm.invoke(formatted_prompt)
        # The enhanced prompt is the message content of the LLM response.
        enhanced_prompt = response.content
        return {"raw_prompt": request.raw_prompt, "enhanced_prompt": enhanced_prompt}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def generate_image_from_prompt(image_prompt: str):
    """Generate a poster image (plus optional text) from a prompt.

    Args:
        image_prompt: The prompt sent to the image-generation model.

    Returns:
        dict with keys ``text`` (model's text output or None) and
        ``image_base64`` (PNG image as base64 string, or None if the
        model returned no image part).
    """
    response = client1.models.generate_content(
        model="gemini-2.0-flash-exp-image-generation",
        contents=image_prompt,
        config=types.GenerateContentConfig(
            response_modalities=['Text', 'Image']
        )
    )
    result = {"text": None, "image_base64": None}
    # The response interleaves text and inline-image parts; keep the last
    # of each kind (responses typically contain at most one of each).
    for part in response.candidates[0].content.parts:
        if part.text is not None:
            result["text"] = part.text
        elif part.inline_data is not None:
            # Re-encode through PIL to guarantee PNG output.
            image = Image.open(BytesIO(part.inline_data.data))
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            result["image_base64"] = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return result
# 3. Generate Image endpoint
@app.post("/generate-poster")
async def generate_image(request: GenerateImageRequest):
    """Generate a poster from enhanced_prompt (preferred) or raw_prompt.

    Raises HTTP 400 if neither prompt is provided, HTTP 500 on
    generation failure.
    """
    # Bug fix: validation happens OUTSIDE the try block. Previously the
    # HTTPException(400) was raised inside the try and immediately caught
    # by `except Exception`, so clients received a 500 instead of a 400.
    if request.enhanced_prompt:
        image_prompt = request.enhanced_prompt
    elif request.raw_prompt:
        image_prompt = request.raw_prompt
    else:
        raise HTTPException(status_code=400, detail="Either raw_prompt or enhanced_prompt must be provided.")
    try:
        return generate_image_from_prompt(image_prompt)
    except HTTPException:
        # Let deliberate HTTP errors propagate with their original status.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# 4. Update Image endpoint
@app.post("/update-poster")
async def update_image(request: UpdateImageRequest):
    """Edit an existing poster image according to a text instruction.

    The request carries the current image as base64; the model receives
    both the instruction and the decoded image, and the edited image is
    returned re-encoded as base64 PNG. Raises HTTP 500 on any failure
    (including malformed base64 input).
    """
    try:
        # Decode the incoming base64 image into a PIL image.
        image_bytes = base64.b64decode(request.image_base64)
        image = Image.open(BytesIO(image_bytes))
        # The update call sends the text instruction alongside the image.
        response = client1.models.generate_content(
            model="gemini-2.0-flash-exp-image-generation",
            contents=[request.text_instruction, image],
            config=types.GenerateContentConfig(
                response_modalities=['Text', 'Image']
            )
        )
        updated_result = {"text": None, "updated_image_base64": None}
        for part in response.candidates[0].content.parts:
            if part.text is not None:
                updated_result["text"] = part.text
            elif part.inline_data is not None:
                # Re-encode through PIL to guarantee PNG output.
                updated_image = Image.open(BytesIO(part.inline_data.data))
                buffered = BytesIO()
                updated_image.save(buffered, format="PNG")
                updated_result["updated_image_base64"] = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return updated_result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# --- Run the app ---
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)