import asyncio
import base64
import os
from io import BytesIO
from typing import Optional

from fastapi import FastAPI, HTTPException
from google import genai
from google.genai import types
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from PIL import Image
from pydantic import BaseModel
|
|
| |
|
|
| |
# Credentials are read once, at import time.
# NOTE(review): despite its name, LANGCHAIN_API_KEY is the key handed to ChatGroq
# (via get_llm). LangChain/LangSmith tooling also reads an env var with this exact
# name, so confirm the two uses cannot collide (e.g. if tracing is enabled).
LANGCHAIN_API_KEY = os.environ.get("LANGCHAIN_API_KEY")
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")

# Fail fast so the service never starts with a missing or empty credential.
if not all((LANGCHAIN_API_KEY, GOOGLE_API_KEY)):
    raise EnvironmentError("API keys for LANGCHAIN_API_KEY and GOOGLE_API_KEY must be set as environment variables.")
|
|
def get_llm(api_key: str):
    """Build the Groq-hosted chat model used for prompt enhancement.

    Args:
        api_key: API key forwarded to ``ChatGroq``.

    Returns:
        A ``ChatGroq`` instance for the Llama 4 Scout model, with temperature 1
        and at most 1024 output tokens.
    """
    model_settings = {
        "model": "meta-llama/llama-4-scout-17b-16e-instruct",
        "temperature": 1,
        "max_tokens": 1024,
    }
    return ChatGroq(api_key=api_key, **model_settings)
|
|
| |
# Shared, module-level clients: one Groq chat model for prompt enhancement and
# one Gemini client for image generation/editing.
llm = get_llm(api_key=LANGCHAIN_API_KEY)

# Template for /enhance-prompt; ``{Raw_Prompt}`` is the only input variable.
# NOTE(review): the intro mentions both a "business description" and a "raw poster
# concept", but only one value is supplied; and the first bullet ("always write
# generate image, dont use more text") reads like a garbled instruction. Both are
# runtime prompt text and are left unchanged here — worth revisiting deliberately.
_POSTER_PROMPT_TEMPLATE = '''
You are a Prompt Enhancement AI Assistant specialized in crafting marketing posters that effectively communicate a brand's message and capture audience attention. Using the provided business description and the user's raw poster concept, generate a detailed, professional prompt optimized for AI-driven poster creation.

Enhance the prompt by including:
- always write generate image, dont use more text
- Visual Style: (e.g., flat design, isometric, photorealistic, illustrative)
- Color Scheme & Branding: (e.g., vibrant brand colors, pastel accents, monochrome palette)
- Typography: (e.g., bold sans-serif headline, elegant serif subtext)
- Composition & Layout: (e.g., focal point placement, use of negative space, layered elements)
- Imagery Elements: (e.g., product shots, icons, background patterns, textures)
- Lighting & Mood: (e.g., bright, dramatic shadows, soft ambient glow)
- Contextual Details: (e.g., logo placement, tagline integration, call-to-action placement)
- Call-to-Action: include a clear, concise CTA in the poster text (e.g., "Shop Now", "Learn More")

Business Description:
{Raw_Prompt}

Enhanced Poster Prompt:
'''
prompt_template = PromptTemplate.from_template(_POSTER_PROMPT_TEMPLATE)

client1 = genai.Client(api_key=GOOGLE_API_KEY)
|
|
| |
|
|
class EnhancePromptRequest(BaseModel):
    """JSON body accepted by ``POST /enhance-prompt``."""

    # The user's free-form business/poster description. The capitalised field
    # name is part of the public JSON contract, so it is kept as-is.
    Raw_Prompt: str
|
|
class GenerateImageRequest(BaseModel):
    """JSON body accepted by ``POST /generate-poster``.

    Both fields are optional at the schema level; the route handler decides which
    one to use and rejects requests that provide neither.
    """

    # ``Optional[...]`` rather than a bare ``str = None``: under Pydantic v2 the
    # annotation is enforced on supplied values, so ``str`` alone would reject an
    # explicit JSON ``null`` even though ``None`` is the default.
    Raw_Prompt: Optional[str] = None
    enhanced_prompt: Optional[str] = None
|
|
class UpdateImageRequest(BaseModel):
    """JSON body accepted by ``POST /update-poster``."""

    # Natural-language edit instruction sent to the image model with the image.
    text_instruction: str
    # The image to edit, base64-encoded.
    image_base64: str
|
|
| |
|
|
# The ASGI application; every route decorator in this module registers on it.
app: FastAPI = FastAPI(
    title="Image Generation & Update API",
)
|
|
| |
@app.get("/")
async def root():
    """Return a static welcome message for ``GET /``."""
    greeting = "Welcome to the Image Generation API!"
    return {"message": greeting}
|
|
| |
@app.post("/enhance-prompt")
async def enhance_prompt(request: EnhancePromptRequest):
    """Expand the user's raw poster idea into a detailed image-generation prompt.

    Args:
        request: Body carrying ``Raw_Prompt``, which is substituted into
            ``prompt_template`` and sent to ``llm``.

    Returns:
        ``{"raw_prompt": <request.Raw_Prompt>, "enhanced_prompt": <LLM reply content>}``.

    Raises:
        HTTPException: 500 with the underlying error text if formatting the
            prompt or the LLM call fails.
    """
    try:
        formatted_prompt = prompt_template.invoke({"Raw_Prompt": request.Raw_Prompt})
        # ``ainvoke`` awaits the remote call instead of blocking the event loop;
        # the synchronous ``invoke`` would stall every other request on this worker.
        response = await llm.ainvoke(formatted_prompt)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # The request field is ``Raw_Prompt`` (not ``raw_prompt``); the response key
    # stays lowercase to keep the existing JSON contract.
    return {"raw_prompt": request.Raw_Prompt, "enhanced_prompt": response.content}
|
|
| |
def generate_image_from_prompt(image_prompt: str):
    """Generate a poster for ``image_prompt`` with the Gemini image model.

    This is a synchronous call to the Gemini API through ``client1``.

    Args:
        image_prompt: Text prompt sent as the request contents.

    Returns:
        dict with ``"text"`` (the last text part returned, or None) and
        ``"image_base64"`` (the last image part, re-encoded as base64 PNG, or None).

    Raises:
        RuntimeError: if the response carries no candidate content/parts.
        Exceptions from the google-genai client or Pillow propagate unchanged.
    """
    response = client1.models.generate_content(
        model="gemini-2.0-flash-exp-image-generation",
        contents=image_prompt,
        config=types.GenerateContentConfig(
            response_modalities=['Text', 'Image']
        ),
    )

    # The SDK types ``candidates``, ``content`` and ``parts`` as optional; indexing
    # them blindly produced an opaque IndexError/TypeError instead of a clear error.
    candidates = response.candidates or []
    content = candidates[0].content if candidates else None
    parts = content.parts if content is not None else None
    if parts is None:
        raise RuntimeError("Image model returned no content for the prompt.")

    result = {"text": None, "image_base64": None}
    for part in parts:
        if part.text is not None:
            result["text"] = part.text
        elif part.inline_data is not None:
            # Round-trip through Pillow so the returned image is always PNG;
            # ``with`` releases the decoded image once it has been re-encoded.
            with Image.open(BytesIO(part.inline_data.data)) as image:
                buffered = BytesIO()
                image.save(buffered, format="PNG")
            result["image_base64"] = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return result
|
|
| |
@app.post("/generate-poster")
async def generate_image(request: GenerateImageRequest):
    """Generate a poster image from the enhanced prompt, or the raw prompt as fallback.

    Args:
        request: Body with optional ``enhanced_prompt`` and ``Raw_Prompt``;
            a non-empty ``enhanced_prompt`` takes precedence.

    Returns:
        The dict produced by ``generate_image_from_prompt``.

    Raises:
        HTTPException: 400 if neither prompt is provided; 500 if generation fails.
    """
    # Validate before entering the try block so this 400 is not caught by the
    # generic handler below and rewrapped as a 500.
    image_prompt = request.enhanced_prompt or request.Raw_Prompt
    if not image_prompt:
        raise HTTPException(
            status_code=400,
            detail="Either Raw_Prompt or enhanced_prompt must be provided.",
        )

    try:
        # generate_image_from_prompt makes a blocking network call; run it in a
        # worker thread so the event loop keeps serving other requests.
        return await asyncio.to_thread(generate_image_from_prompt, image_prompt)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
| |
@app.post("/update-poster")
async def update_image(request: UpdateImageRequest):
    """Edit an existing poster image according to a text instruction.

    Args:
        request: ``text_instruction`` plus ``image_base64``. The image may be plain
            base64 or a ``data:<mime>;base64,<payload>`` URL.

    Returns:
        dict with ``"text"`` (the last text part returned, or None) and
        ``"updated_image_base64"`` (the last image part as base64 PNG, or None).

    Raises:
        HTTPException: 400 if ``image_base64`` does not decode to an image Pillow
            can read; 500 if the model call or processing of its output fails.
    """
    encoded = request.image_base64.strip()
    # Accept data URLs (e.g. produced by FileReader.readAsDataURL): keep only the
    # payload after the comma. Otherwise the prefix letters would be decoded as data.
    if encoded.startswith("data:"):
        encoded = encoded.partition(",")[2]

    try:
        image = Image.open(BytesIO(base64.b64decode(encoded)))
        # Image.open only parses the header; force a full decode so corrupt or
        # truncated uploads are reported here as client errors.
        image.load()
    except (ValueError, OSError, Image.DecompressionBombError) as e:
        # binascii.Error (bad base64) subclasses ValueError; Pillow's
        # UnidentifiedImageError and truncated-file errors subclass OSError.
        raise HTTPException(status_code=400, detail=f"Invalid image_base64: {e}") from e

    try:
        # Async client: awaiting keeps the event loop free during the API call.
        response = await client1.aio.models.generate_content(
            model="gemini-2.0-flash-exp-image-generation",
            contents=[request.text_instruction, image],
            config=types.GenerateContentConfig(
                response_modalities=['Text', 'Image']
            ),
        )

        # candidates/content/parts are optional in the SDK types; fail clearly.
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = content.parts if content is not None else None
        if parts is None:
            raise RuntimeError("Image model returned no content for the update request.")

        updated_result = {"text": None, "updated_image_base64": None}
        for part in parts:
            if part.text is not None:
                updated_result["text"] = part.text
            elif part.inline_data is not None:
                # Normalise the returned image to PNG before base64-encoding it.
                with Image.open(BytesIO(part.inline_data.data)) as updated_image:
                    buffered = BytesIO()
                    updated_image.save(buffered, format="PNG")
                updated_result["updated_image_base64"] = base64.b64encode(
                    buffered.getvalue()
                ).decode("utf-8")
        return updated_result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
| |
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on port 8000.
    # NOTE(review): host 0.0.0.0 listens on every interface — confirm that is intended.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
|
|