|
|
import base64
import os
from io import BytesIO
from typing import Optional

from fastapi import FastAPI, HTTPException
from google import genai
from google.genai import types
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from PIL import Image
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Credentials are read from the environment so they never live in source control.
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Fail fast at import time: a missing key would otherwise surface much later
# as an opaque auth error on the first request.
if not (LANGCHAIN_API_KEY and GOOGLE_API_KEY):
    raise EnvironmentError("API keys for LANGCHAIN_API_KEY and GOOGLE_API_KEY must be set as environment variables.")
|
|
|
|
|
def get_llm(
    api_key: str,
    model: str = "meta-llama/llama-4-scout-17b-16e-instruct",
    temperature: float = 1,
    max_tokens: int = 1024,
):
    """Build and return a ChatGroq LLM instance.

    Generalized: the previously hard-coded model name, temperature and token
    limit are now keyword parameters whose defaults preserve the original
    behavior, so existing callers (``get_llm(LANGCHAIN_API_KEY)``) are
    unaffected while other configurations become possible.

    Args:
        api_key: Groq API key used to authenticate the client.
        model: Groq model identifier.
        temperature: Sampling temperature passed through to ChatGroq.
        max_tokens: Maximum tokens the model may generate per reply.

    Returns:
        A configured ``ChatGroq`` instance.
    """
    return ChatGroq(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        api_key=api_key,
    )
|
|
|
|
|
|
|
|
# Module-level LLM client shared by the /enhance-prompt endpoint.
llm = get_llm(LANGCHAIN_API_KEY)

# Template that expands a user's raw poster idea into a detailed,
# photography-oriented generation prompt. Single input variable: {Raw_Prompt}.
prompt_template = PromptTemplate.from_template('''
You are a Prompt Enhancement AI Assistant. Your task is to take the user's raw poster image prompt and convert it into a detailed, professional prompt optimized for generating high-quality AI poster.

Enhance the prompt by including relevant details such as:
- Camera specifications (e.g., lens type, aperture, focal length)
- Lighting setup (e.g., natural light, studio lighting, soft shadows)
- Camera angle (e.g., top-down, macro, isometric, side view)
- Background style (e.g., plain white, minimalistic, outdoor, studio backdrop)
- Scene composition (e.g., centered product, depth of field, reflections)
- Call-to-Action: include a clear, concise CTA in the poster text (e.g., "Shop Now", "Learn More")

Focus only on *Poster photography* — do not include humans or models.

Raw Prompt:
{Raw_Prompt}

Enhanced Prompt:
''')

# Google GenAI client used by the poster generation/update endpoints (Gemini).
client1 = genai.Client(api_key=GOOGLE_API_KEY)
|
|
|
|
|
|
|
|
|
|
|
class EnhancePromptRequest(BaseModel):
    """Request body for POST /enhance-prompt."""

    # The user's raw poster idea, to be expanded by the LLM.
    raw_prompt: str
|
|
|
|
|
class GenerateImageRequest(BaseModel):
    """Request body for POST /generate-poster.

    At least one of the two prompts must be provided; ``enhanced_prompt``
    takes precedence over ``raw_prompt`` when both are set.
    """

    # Bug fix: the fields were declared ``str = None``, an annotation/default
    # mismatch that Pydantic v2 rejects at validation time. Declaring them
    # Optional[str] makes "field omitted" valid while keeping the same
    # external JSON shape and defaults.
    raw_prompt: Optional[str] = None
    enhanced_prompt: Optional[str] = None
|
|
|
|
|
class UpdateImageRequest(BaseModel):
    """Request body for POST /update-poster."""

    # Natural-language edit instruction applied to the supplied image.
    text_instruction: str
    # The current poster image, base64-encoded (decoded via base64.b64decode).
    image_base64: str
|
|
|
|
|
|
|
|
|
|
|
# ASGI application exposing the prompt-enhancement and poster endpoints.
app = FastAPI(title="Image Generation & Update API")
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Landing/health-check route for the API root."""
    welcome = {"message": "Welcome to the Image Generation API!"}
    return welcome
|
|
|
|
|
|
|
|
@app.post("/enhance-prompt")
async def enhance_prompt(request: EnhancePromptRequest):
    """Expand a raw poster idea into a detailed generation prompt.

    Renders the module-level prompt template with the caller's raw prompt,
    sends it to the LLM, and returns both the original and the enhanced
    prompt. Any failure is surfaced as a 500 with the error text as detail.
    """
    try:
        rendered = prompt_template.invoke({"Raw_Prompt": request.raw_prompt})
        llm_reply = llm.invoke(rendered)
        return {"raw_prompt": request.raw_prompt, "enhanced_prompt": llm_reply.content}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
def generate_image_from_prompt(image_prompt: str):
    """Ask Gemini to generate a poster image (plus optional text).

    Args:
        image_prompt: The textual prompt describing the desired poster.

    Returns:
        A dict with keys ``"text"`` (any text part Gemini emitted, or None)
        and ``"image_base64"`` (the generated image re-encoded as a base64
        PNG string, or None when no image part was returned).
    """
    response = client1.models.generate_content(
        model="gemini-2.0-flash-exp-image-generation",
        contents=image_prompt,
        config=types.GenerateContentConfig(response_modalities=['Text', 'Image']),
    )

    payload = {"text": None, "image_base64": None}

    for part in response.candidates[0].content.parts:
        if part.text is not None:
            payload["text"] = part.text
            continue
        if part.inline_data is None:
            continue
        # Round-trip through PIL to normalize the returned bytes to PNG
        # before base64-encoding for the JSON response.
        png_buffer = BytesIO()
        Image.open(BytesIO(part.inline_data.data)).save(png_buffer, format="PNG")
        payload["image_base64"] = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

    return payload
|
|
|
|
|
|
|
|
@app.post("/generate-poster")
async def generate_image(request: GenerateImageRequest):
    """Generate a poster image from either the enhanced or the raw prompt.

    Prefers ``enhanced_prompt`` when both prompts are supplied. Responds 400
    when neither prompt is provided and 500 when image generation fails.
    """
    # Bug fix: this validation previously lived inside the try block, so the
    # 400 HTTPException it raises was caught by `except Exception` below and
    # re-wrapped as a 500. Validate before entering the handler.
    if request.enhanced_prompt:
        image_prompt = request.enhanced_prompt
    elif request.raw_prompt:
        image_prompt = request.raw_prompt
    else:
        raise HTTPException(status_code=400, detail="Either raw_prompt or enhanced_prompt must be provided.")

    try:
        return generate_image_from_prompt(image_prompt)
    except HTTPException:
        # Preserve deliberate HTTP errors raised downstream instead of
        # flattening them into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
@app.post("/update-poster")
async def update_image(request: UpdateImageRequest):
    """Edit an existing poster image according to a text instruction.

    Decodes the caller's base64 image, sends it to Gemini together with the
    instruction, and returns any text plus the updated image as base64 PNG.

    Raises:
        HTTPException: 400 when ``image_base64`` is not a decodable image;
            500 when the Gemini call or re-encoding fails.
    """
    # Bug fix: malformed client input previously surfaced as a 500. Decode
    # and parse the image separately so bad requests get a 400.
    # (base64.b64decode raises binascii.Error, a ValueError subclass; PIL
    # raises UnidentifiedImageError, an OSError subclass.)
    try:
        image_bytes = base64.b64decode(request.image_base64)
        image = Image.open(BytesIO(image_bytes))
    except (ValueError, OSError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image_base64: {e}")

    try:
        response = client1.models.generate_content(
            model="gemini-2.0-flash-exp-image-generation",
            contents=[request.text_instruction, image],
            config=types.GenerateContentConfig(response_modalities=['Text', 'Image']),
        )

        updated_result = {"text": None, "updated_image_base64": None}
        for part in response.candidates[0].content.parts:
            if part.text is not None:
                updated_result["text"] = part.text
            elif part.inline_data is not None:
                # Normalize the returned bytes to PNG before base64-encoding.
                updated_image = Image.open(BytesIO(part.inline_data.data))
                buffered = BytesIO()
                updated_image.save(buffered, format="PNG")
                updated_result["updated_image_base64"] = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return updated_result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
# Local development entry point; in production run via an external ASGI
# server instead (e.g. `uvicorn module:app`).
if __name__ == "__main__":
    # Imported lazily so the module itself does not require uvicorn
    # when served by another ASGI runner.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|