import os
import base64
from io import BytesIO
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Import the required packages for the language model and prompts
from langchain_groq import ChatGroq
from langchain_core.prompts import PromptTemplate

# Import packages for image processing and Google GenAI client
from google import genai
from google.genai import types
from PIL import Image

# --- Global Initialization ---

# Get API keys from environment variables (LANGCHAIN_API_KEY is expected to hold the Groq API key, GOOGLE_API_KEY the Gemini key)
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

if not LANGCHAIN_API_KEY or not GOOGLE_API_KEY:
    raise EnvironmentError("API keys for LANGCHAIN_API_KEY and GOOGLE_API_KEY must be set as environment variables.")

def get_llm(api_key: str):
    """
    Returns the ChatGroq LLM instance.
    """
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=1,
        max_tokens=1024,
        api_key=api_key
    )
    return llm

# Initialize the LLM instance
llm = get_llm(LANGCHAIN_API_KEY)

# Define the prompt template for enhancing raw prompts.
prompt_template = PromptTemplate.from_template('''
You are a Prompt Enhancement AI Assistant. Your task is to take the user's raw poster image prompt and convert it into a detailed, professional prompt optimized for generating a high-quality AI poster.

Enhance the prompt by including relevant details such as:
- Camera specifications (e.g., lens type, aperture, focal length)
- Lighting setup (e.g., natural light, studio lighting, soft shadows)
- Camera angle (e.g., top-down, macro, isometric, side view)
- Background style (e.g., plain white, minimalistic, outdoor, studio backdrop)
- Scene composition (e.g., centered product, depth of field, reflections)
- Call-to-Action: include a clear, concise CTA in the poster text (e.g., "Shop Now", "Learn More")

Focus only on *Poster photography* — do not include humans or models.

Raw Prompt:
{Raw_Prompt}

Enhanced Prompt:
''')

# Initialize the Google GenAI client.
client1 = genai.Client(api_key=GOOGLE_API_KEY)

# --- Request Models ---

class EnhancePromptRequest(BaseModel):
    raw_prompt: str

class GenerateImageRequest(BaseModel):
    # If both are provided, enhanced_prompt takes priority.
    raw_prompt: Optional[str] = None
    enhanced_prompt: Optional[str] = None

class UpdateImageRequest(BaseModel):
    text_instruction: str
    # Image encoded in base64
    image_base64: str

# --- FastAPI Initialization ---

app = FastAPI(title="Image Generation & Update API")

# 1. Root endpoint
@app.get("/")
async def root():
    return {"message": "Welcome to the Image Generation API!"}

# 2. Enhance Prompt endpoint
@app.post("/enhance-prompt")
async def enhance_prompt(request: EnhancePromptRequest):
    try:
        # Prepare the prompt using the template.
        formatted_prompt = prompt_template.invoke({"Raw_Prompt": request.raw_prompt})
        # Call the LLM to enhance the prompt.
        response = llm.invoke(formatted_prompt)
        # The enhanced prompt text is returned in response.content.
        enhanced_prompt = response.content
        return {"raw_prompt": request.raw_prompt, "enhanced_prompt": enhanced_prompt}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
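
# Illustrative exchange for /enhance-prompt (a sketch; the exact enhanced text
# depends on the LLM output and will vary between calls):
#   POST /enhance-prompt   {"raw_prompt": "poster of a coffee mug"}
#   -> {"raw_prompt": "poster of a coffee mug", "enhanced_prompt": "<detailed prompt text>"}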

# Helper function to generate image content using the GenAI client.
def generate_image_from_prompt(image_prompt: str):
    response = client1.models.generate_content(
        model="gemini-2.0-flash-exp-image-generation",
        contents=image_prompt,
        config=types.GenerateContentConfig(
            response_modalities=['Text', 'Image']
        )
    )
    # Prepare a dict to hold results.
    result = {"text": None, "image_base64": None}
    # Process the response parts.
    for part in response.candidates[0].content.parts:
        if part.text is not None:
            result["text"] = part.text
        elif part.inline_data is not None:
            # Open the image from bytes.
            image = Image.open(BytesIO(part.inline_data.data))
            # Convert image to bytes.
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            # Encode to base64.
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            result["image_base64"] = img_str
    return result
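
# The helper returns a dict of the form
#   {"text": <optional text part>, "image_base64": <base64-encoded PNG string>};
# either field stays None if the model response did not include that modality.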

# 3. Generate Image endpoint
@app.post("/generate-poster")
async def generate_image(request: GenerateImageRequest):
    # Decide which prompt to use; enhanced_prompt takes priority.
    if request.enhanced_prompt:
        image_prompt = request.enhanced_prompt
    elif request.raw_prompt:
        image_prompt = request.raw_prompt
    else:
        # Validate outside the try block so this 400 is not converted into a 500 below.
        raise HTTPException(status_code=400, detail="Either raw_prompt or enhanced_prompt must be provided.")

    try:
        result = generate_image_from_prompt(image_prompt)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
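
# Illustrative request for /generate-poster (a sketch; either prompt field is
# accepted, and enhanced_prompt wins when both are sent):
#   POST /generate-poster  {"enhanced_prompt": "Studio shot of a ceramic mug, 50mm lens, soft shadows, 'Shop Now' text"}
#   -> {"text": <optional text>, "image_base64": "<base64-encoded PNG>"}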

# 4. Update Image endpoint
@app.post("/update-poster")
async def update_image(request: UpdateImageRequest):
    try:
        # Decode the incoming base64 image.
        image_bytes = base64.b64decode(request.image_base64)
        image = Image.open(BytesIO(image_bytes))
        
        # The contents for the update API include the text instruction and the current image.
        response = client1.models.generate_content(
            model="gemini-2.0-flash-exp-image-generation",
            contents=[request.text_instruction, image],
            config=types.GenerateContentConfig(
                response_modalities=['Text', 'Image']
            )
        )
        updated_result = {"text": None, "updated_image_base64": None}
        for part in response.candidates[0].content.parts:
            if part.text is not None:
                updated_result["text"] = part.text
            elif part.inline_data is not None:
                updated_image = Image.open(BytesIO(part.inline_data.data))
                buffered = BytesIO()
                updated_image.save(buffered, format="PNG")
                updated_img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
                updated_result["updated_image_base64"] = updated_img_str
        return updated_result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
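
# Illustrative request for /update-poster (a sketch; image_base64 is expected to be a
# base64-encoded image, e.g. the image_base64 value returned by /generate-poster):
#   POST /update-poster  {"text_instruction": "Change the background to dark blue",
#                         "image_base64": "<base64-encoded PNG>"}
#   -> {"text": <optional text>, "updated_image_base64": "<base64-encoded PNG>"}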

# --- Run the app ---
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
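
# Minimal client sketch (commented out so importing this module stays side-effect free),
# assuming the server is running locally on port 8000 and the `requests` package is
# installed; endpoint names and payload keys match the routes defined above:
#
#   import base64
#   import requests
#
#   BASE = "http://localhost:8000"
#
#   # 1. Enhance a raw prompt.
#   enhanced = requests.post(f"{BASE}/enhance-prompt",
#                            json={"raw_prompt": "poster of a coffee mug"}).json()
#
#   # 2. Generate a poster from the enhanced prompt.
#   poster = requests.post(f"{BASE}/generate-poster",
#                          json={"enhanced_prompt": enhanced["enhanced_prompt"]}).json()
#
#   # 3. Ask for an update to the generated poster.
#   updated = requests.post(f"{BASE}/update-poster",
#                           json={"text_instruction": "Add the text 'Shop Now' in bold",
#                                 "image_base64": poster["image_base64"]}).json()
#
#   # Decode and save the updated poster locally.
#   with open("poster.png", "wb") as f:
#       f.write(base64.b64decode(updated["updated_image_base64"]))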