from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain.memory import ConversationTokenBufferMemory
import os
import json
import requests
import time
from PIL import Image
from io import BytesIO
from dotenv import load_dotenv
import tempfile
from fastapi.responses import FileResponse

# Load environment variables
load_dotenv()

# Initialize FastAPI app
app = FastAPI(title="Advanced AI Mock-up FastAPI")

# Configure API keys and global dependencies
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
url = os.getenv("IMAGE_API_URL")
API_KEY = os.getenv("IMAGE_API_KEY")

if not OPENAI_API_KEY or not API_KEY or not url:
    raise EnvironmentError(
        "Missing configuration. Please set OPENAI_API_KEY, IMAGE_API_KEY, "
        "and IMAGE_API_URL in the environment variables."
    )

# Configure the OpenAI client and the LangChain LLM wrapper.
# Alias the LangChain import so it does not shadow the OpenAI SDK client.
from openai import OpenAI
from langchain_openai import OpenAI as LangChainOpenAI

client = OpenAI()
llm = LangChainOpenAI()
memory = ConversationTokenBufferMemory(llm=llm, max_token_limit=4000)
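# ConversationTokenBufferMemory keeps a rolling window of the most recent
# exchanges, pruning the oldest turns once the stored history exceeds the
# max_token_limit (token counts are estimated with the wrapped LLM).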
# API key and headers for image generation
headers = {
    "accept": "application/json",
    "x-key": API_KEY,
    "Content-Type": "application/json"
}

# Pydantic model for input
class ConversationRequest(BaseModel):
    question: str

# Function to manage greeting
def Greeting(question, chat_history):
    prompt = f"""
    You are a professional AI assistant specialized in AI-powered mock-up creation. Start with a warm greeting, ask about the user's well-being, and steer the conversation toward AI-powered mock-up creation for jackets or other apparel. Keep the tone friendly and professional.
    Chat History:
    {chat_history}
    """
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"Question: {question}"}
        ]
    )
    return response.choices[0].message.content
def select_state(chat_history):
    output_format = '''
    Answer according to the following JSON format:
    {
        "State": "Here you will select one state based on chat history: 'greeting', 'gather_info', 'analyze_chat_history', 'generate_images'"
    }'''
    prompt = f"""
    Based on the chat history below, decide the state for the agent. The state can be:
    - 'greeting': if the chat history lacks a greeting message.
    - 'gather_info': if greeting messages (like 'hi', 'hello', 'how are you') have already been exchanged.
    - 'analyze_chat_history': if sufficient information has been gathered, including:
        - Team name.
        - Colors or style preferences.
        - Details about patterns, or any unique requirements.
    - 'generate_images': if image prompts have been generated.
    Chat History:
    {chat_history}
    """ + output_format
    response = client.chat.completions.create(
        model="gpt-4o",
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": "Select the next state"}
        ]
    )
    json_data = json.loads(response.choices[0].message.content)
    return json_data['State']
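# Note: select_state acts as an LLM-based router. Each turn the full chat
# history is sent to the model, which picks one of the four states above;
# manage_conversation() then dispatches to the matching handler.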
# Function to gather information
def Gather_info(question, chat_history):
    prompt = f"""
    You are an information-gathering agent specialized in AI-powered mock-up creation. Your task is to politely gather the following information from the user:
    - Team: Ask what team this is for.
    - Team colors: Ask for the team colors or other specific colors they want to use.
    - Style guide: Ask whether the user can provide details about the team's style guide.
    - Additional details: Gather any other specific information about the team, such as patterns or unique requirements.
    Please ask these questions one by one in a friendly and engaging manner, and ensure you document all the provided details accurately.
    Chat History:
    {chat_history}
    """
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"Question: {question}"}
        ]
    )
    return response.choices[0].message.content
def analyze_chat_history(chat_history):
    output_format = '''
    Answer according to the following JSON format:
    {
        "Analysis": "Provide a summary analysis of the chat history, focusing on key insights derived from the gathered information.",
        "NextAction": "Specify the next logical action: either continue the conversation or conclude it.",
        "Prompts": [
            "Prompt 1: Detailed prompt for generating the first mock-up.",
            "Prompt 2: Detailed prompt for generating the second mock-up.",
            "Prompt 3: Detailed prompt for generating the third mock-up.",
            "Prompt 4: Detailed prompt for generating the fourth mock-up."
        ]
    }'''
    prompt = f"""
    You are a highly intelligent and efficient analysis agent tasked with processing the chat history provided below. Based solely on the relevant information gathered by the information-gathering agent, your responsibilities are to:
    1. Summarize the user's key points and design requirements with precision, highlighting the essential elements.
    2. Generate 4 detailed and creative prompts for image mock-ups tailored to the user's specific needs.
    3. Keep the jacket description identical across all prompts so the jacket is the same in every image, shown from a different view.
    Ensure that the generated prompts adhere to the following criteria:
    - Visually compelling, emphasizing creativity, detail, and storytelling.
    - Highly specific, incorporating the following aspects where applicable:
        - Key themes, team dynamics, or user-specified concepts.
        - Color schemes, textures, and style guidelines.
        - Camera and lens settings: Recommend camera models (e.g., Canon EOS R5, Nikon Z9), lenses (e.g., 50mm f/1.8 for portraits or 85mm for close-ups), and techniques (e.g., shallow depth of field, macro for texture).
        - Artistic enhancements: Suggest details like angles (e.g., low-angle, top-down), effects (e.g., bokeh, soft focus), or scene accents (e.g., props or natural textures).
        - Aspect ratio and style tags: Specify dimensions (e.g., --ar 16:9 for banners or --ar 4:5 for Instagram). Include style tags like --style cinematic, --style raw, or --style editorial.
        - Lighting details, including time of day, intensity, direction, and color temperature.
        - Composition elements like framing, depth of field, symmetry, and rule of thirds.
        - Environmental and contextual details that provide additional realism or artistic flair.
    - Clearly structured to provide effective guidance for advanced image generation models.
    - Prompts should focus on the provided color combination. Do not add anything of your own; use only the context the user has provided.
    - Do not include humans in the images. Generate only images of the jackets on a white background.
    Chat History:
    {chat_history}
    """ + output_format
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": "Analyze the conversation and generate prompts"}
        ]
    )
    # Extract response content
    response_content = response.choices[0].message.content.strip()
    # Strip a Markdown code fence if the model wrapped its JSON in one
    if response_content.startswith("```") and response_content.endswith("```"):
        response_content = response_content[response_content.find("\n") + 1 : -3].strip()
    if not response_content:
        raise ValueError("The API returned an empty response.")
    try:
        json_data = json.loads(response_content)
    except json.JSONDecodeError as e:
        print("Error parsing JSON:", e)
        print("Content causing error:", response_content)
        raise
    return json_data
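# Design note: this call does not set response_format={"type": "json_object"}
# (unlike select_state), which is why the fence-stripping above is needed;
# enabling it would make the manual cleanup largely redundant.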
# Temporary directory reserved for storing generated images locally
temp_dir = tempfile.TemporaryDirectory()

def generate_images(prompts, url, headers):
    image_links = []
    for index, prompt in enumerate(prompts, start=1):
        print(f"Generating image {index} of {len(prompts)}")
        payload = {
            "prompt": prompt,
            "width": 1024,
            "height": 1024,
            "guidance_scale": 1,
            "num_inference_steps": 50,
            "max_sequence_length": 512,
        }
        response = requests.post(url, headers=headers, json=payload).json()
        if "id" not in response:
            print("Error in generating image:", response)
            continue
        request_id = response["id"]
        print(f"Image generation request ID for prompt {index}: {request_id}")
        # Poll until the request reaches a terminal state
        while True:
            time.sleep(0.5)
            result = requests.get(
                "https://api.bfl.ml/v1/get_result",
                headers=headers,
                params={"id": request_id},
            ).json()
            if result["status"] == "Ready":
                if "result" in result and "sample" in result["result"]:
                    image_url = result["result"]["sample"]
                    print(f"Generated image URL for prompt {index}: {image_url}")
                    image_links.append(image_url)
                else:
                    print(f"Error: 'sample' key not found in the result for prompt {index}.")
                break
            elif result["status"] != "Pending":
                # Treat any non-pending, non-ready status (e.g. a moderation
                # or failure state) as terminal to avoid polling forever.
                print(f"Image generation failed for prompt {index}: {result['status']}")
                break
            else:
                print(f"Image generation status for prompt {index}: {result['status']}")
    return image_links
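# Note: this polling pattern follows the Black Forest Labs (api.bfl.ml) flow:
# the POST returns a request id, and /v1/get_result is polled until the status
# is "Ready", at which point result["sample"] holds a URL to the image. The
# sample URLs are typically short-lived, so download promptly if the images
# need to be persisted.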
def manage_conversation(question, url, headers, memory):
    chat_history = memory.load_memory_variables({})
    chat_history = chat_history['history']
    # Get the current state
    state = select_state(chat_history)
    if state == "greeting":
        response = Greeting(question, chat_history)
    elif state == "gather_info":
        response = Gather_info(question, chat_history)
    elif state == "analyze_chat_history":
        response = analyze_chat_history(chat_history)
        # Serialize the JSON response to a string if it's a dictionary
        response = json.dumps(response, indent=4)
    elif state == "generate_images":
        prompts = analyze_chat_history(chat_history)['Prompts']
        image_links = generate_images(prompts, url=url, headers=headers)
        response = json.dumps({"message": "Images generated successfully.", "image_links": image_links}, indent=4)
    else:
        response = "Conversation ended."
    # Save the response to memory as a string
    memory.save_context({"input": question}, {"output": response})
    return response
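# Flow summary: each turn is routed by select_state(). A typical session moves
# greeting -> gather_info (possibly over several turns) -> analyze_chat_history
# (which also produces the four image prompts) -> generate_images.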
# API endpoint for the conversation loop (route path assumed)
@app.post("/conversation")
async def conversation_endpoint(request: ConversationRequest):
    try:
        response = manage_conversation(request.question, url, headers, memory)
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Endpoint to start a fresh session (route path assumed)
@app.post("/new_chat")
async def new_chat():
    """
    Reset the conversation memory and start a new chat session.
    """
    try:
        global memory
        memory = ConversationTokenBufferMemory(llm=llm, max_token_limit=4000)
        return {"message": "New chat session started successfully."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Health-check endpoint
@app.get("/")
async def root():
    return {"message": "API is up and running!"}