Spaces:
Running
Running
| import os | |
| from dotenv import load_dotenv | |
| import streamlit as st | |
| import requests | |
| from PIL import Image, ImageDraw, ImageFont | |
| import io | |
| import base64 | |
| import easyocr | |
| import numpy as np | |
| import cv2 | |
# Load environment variables
load_dotenv()

# Set up logging
import logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Hugging Face API setup
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # Avoid sending the literal header "Bearer None" when no token is configured.
    logger.warning("HF_TOKEN is not set; requests will be sent without an Authorization header")
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}


@st.cache_resource
def _load_ocr_reader():
    """Build the English EasyOCR reader once and reuse it.

    Streamlit re-executes this script on every widget interaction; caching the
    reader as a shared resource avoids reloading the OCR models on each rerun.
    """
    return easyocr.Reader(['en'])


# Initialize EasyOCR reader (module-level name used by the OCR helpers)
reader = _load_ocr_reader()
def query(payload, timeout=120):
    """POST ``payload`` as JSON to the Hugging Face inference endpoint ``API_URL``.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": prompt}``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block forever and freeze the app.

    Returns:
        The decoded JSON body when the response Content-Type is JSON, the raw
        response bytes when it is an image type, otherwise ``None``. On any
        request failure the error is logged, shown via ``st.error`` and
        ``None`` is returned.
    """
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        # Log before raise_for_status so failing responses are recorded too.
        logger.debug(f"API response status code: {response.status_code}")
        logger.debug(f"API response headers: {response.headers}")
        response.raise_for_status()
        content_type = response.headers.get('Content-Type', '')
        if 'application/json' in content_type:
            return response.json()
        elif 'image' in content_type:
            return response.content
        else:
            logger.error(f"Unexpected content type: {content_type}")
            st.error(f"Unexpected content type: {content_type}")
            return None
    except requests.exceptions.RequestException as e:
        logger.error(f"Request failed: {str(e)}")
        failed_response = getattr(e, "response", None)
        if failed_response is not None:
            # The body of an error response is the most useful diagnostic.
            logger.error(f"Response body: {failed_response.text[:500]}")
        st.error(f"Request failed: {str(e)}")
        return None
def increase_image_quality(image, scale_factor):
    """Return ``image`` upscaled by ``scale_factor`` using Lanczos resampling.

    This only interpolates pixels onto a larger grid; it does not recover detail.

    Args:
        image: PIL image to upscale.
        scale_factor: Positive multiplier (int or float) applied to both axes.

    Returns:
        A new, resized PIL image.

    Raises:
        ValueError: if ``scale_factor`` is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}")
    width, height = image.size
    # resize() needs integer dimensions; round so float factors work too,
    # and keep each axis at least one pixel wide.
    new_size = (
        max(1, int(round(width * scale_factor))),
        max(1, int(round(height * scale_factor))),
    )
    return image.resize(new_size, Image.LANCZOS)
def extract_text_from_image(image):
    """Run EasyOCR on ``image`` and return all recognised strings joined by spaces.

    The strings appear in the order ``reader.readtext`` reports the detections.
    """
    pixels = np.array(image)
    detections = reader.readtext(pixels)
    recognised = (detection[1] for detection in detections)
    return " ".join(recognised)
def remove_text_from_image(image, text_to_remove):
    """Erase the first OCR-detected text region that contains ``text_to_remove``.

    The image is OCR'd with EasyOCR. The first detection whose text contains
    ``text_to_remove`` (case-insensitive substring match) is covered by a
    rectangular mask and filled in with OpenCV's Telea inpainting.

    Args:
        image: PIL image to edit; it is converted to RGB before processing.
        text_to_remove: Substring to look for in the detected text.

    Returns:
        ``(edited_image, top_left, (width, height))`` where ``top_left`` and the
        size describe the erased axis-aligned box, or ``(image, None, None)``
        (input returned unchanged) when no detection matches.
    """
    needle = text_to_remove.lower()
    # Normalise to 3-channel RGB so the RGB->BGR conversion below is valid for
    # grayscale/palette/alpha images as well.
    rgb_array = np.array(image.convert("RGB"))
    results = reader.readtext(rgb_array)
    for (bbox, text, prob) in results:
        if needle not in text.lower():
            continue
        # bbox is a sequence of (x, y) corner points; take the extremes over all
        # of them so rotated or skewed detections are fully covered.
        xs = [int(point[0]) for point in bbox]
        ys = [int(point[1]) for point in bbox]
        top_left = (min(xs), min(ys))
        bottom_right = (max(xs), max(ys))
        # Convert image to OpenCV (BGR) format
        img_cv = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
        # Create a filled mask over the text region for inpainting
        mask = np.zeros(img_cv.shape[:2], dtype=np.uint8)
        cv2.rectangle(mask, top_left, bottom_right, 255, -1)
        inpainted = cv2.inpaint(img_cv, mask, 3, cv2.INPAINT_TELEA)
        # Convert back to a PIL Image
        edited = Image.fromarray(cv2.cvtColor(inpainted, cv2.COLOR_BGR2RGB))
        return edited, top_left, (bottom_right[0] - top_left[0], bottom_right[1] - top_left[1])
    logger.warning(f"Text '{text_to_remove}' not found in the image.")
    return image, None, None
def add_text_to_image(image, text, font_size=40, font_color="#FFFFFF", position=None, size=None):
    """Draw ``text`` onto ``image`` in place and return the same image.

    Args:
        image: PIL image to draw on; it is modified in place.
        text: String to render.
        font_size: Starting point size for "Roboto-Bold.ttf". It is reduced one
            point at a time until the text fits inside ``size``.
        font_color: Fill colour passed to ``ImageDraw.text``, e.g. "#FFFFFF".
        position: ``(x, y)`` where the top-left of the rendered glyphs should
            land, e.g. the ``top_left`` returned by ``remove_text_from_image``.
        size: ``(width, height)`` box the rendered text must fit inside.

    If either ``position`` or ``size`` is missing, the text is fitted to the
    whole image and centred.

    Returns:
        ``image`` with the text drawn on it.
    """
    try:
        font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
        scalable = True
    except IOError:
        logger.warning("Roboto-Bold font not found, using default font")
        font = ImageFont.load_default()
        # The fallback font is never reloaded at other sizes, so it cannot be
        # shrunk; reloading Roboto in the loop would raise again.
        scalable = False

    img_width, img_height = image.size
    centred = position is None or size is None
    if centred:
        size = (img_width, img_height)

    left, top, right, bottom = font.getbbox(text)
    # Shrink until the glyph box fits; stop at 1pt so truetype() is never
    # asked for a non-positive size.
    while scalable and font_size > 1 and (right - left > size[0] or bottom - top > size[1]):
        font_size -= 1
        font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
        left, top, right, bottom = font.getbbox(text)

    if centred:
        position = ((img_width - (right - left)) // 2, (img_height - (bottom - top)) // 2)

    logger.debug(f"Adding text at position: {position}")
    # getbbox() is measured relative to the drawing anchor, so shift by
    # (left, top) to make the visible glyphs start exactly at ``position``.
    draw = ImageDraw.Draw(image)
    draw.text((position[0] - left, position[1] - top), text, font=font, fill=font_color)
    return image
def _build_prompt(poster_type, prompt):
    """Return the text-to-image prompt for the chosen poster type."""
    if poster_type == "Other":
        return f"A colorful poster with the following elements: {prompt}"
    return f"A colorful {poster_type.lower()} poster with the following elements: {prompt}"


def _image_to_png_bytes(image):
    """Encode a PIL image as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return buffer.getvalue()


def _stored_image_indices():
    """Return the sorted ``i`` values of every ``image_{i}`` key in session state.

    Only keys of exactly that form are counted, so unrelated session-state
    entries cannot produce poster choices that have no stored image.
    """
    prefix = "image_"
    return sorted(
        int(key[len(prefix):])
        for key in st.session_state.keys()
        if isinstance(key, str) and key.startswith(prefix) and key[len(prefix):].isdigit()
    )


def _generation_section():
    """Render the "Generate Poster" form; store newly generated posters as PNG bytes."""
    st.header("Generate Poster")
    poster_type = st.selectbox("Poster Type", ["Fashion", "Movie", "Event", "Advertisement", "Other"])
    prompt = st.text_area("Prompt")
    num_images = st.number_input("Number of Images", min_value=1, max_value=5, value=1)
    quality_factor = st.number_input("Quality Factor", min_value=1, max_value=4, value=1)
    if not st.button("Generate Images"):
        return

    full_prompt = _build_prompt(poster_type, prompt)
    generated_images = []
    for i in range(num_images):
        with st.spinner(f"Generating image {i+1}..."):
            logger.info(f"Generating image {i+1} with prompt: {full_prompt}")
            response = query({"inputs": full_prompt})
            if isinstance(response, bytes):
                image = Image.open(io.BytesIO(response))
                if quality_factor > 1:
                    image = increase_image_quality(image, quality_factor)
                generated_images.append(image)
            else:
                st.error("Failed to generate image")

    if generated_images:
        # Drop posters from an earlier batch so the editor never offers a stale
        # image whose index lies beyond this batch.
        for index in _stored_image_indices():
            del st.session_state[f'image_{index}']

    # Display generated images and save them to session state for editing
    for i, img in enumerate(generated_images):
        st.image(img, caption=f"Generated Poster {i+1}", use_column_width=True)
        st.session_state[f'image_{i}'] = _image_to_png_bytes(img)


def _editing_section():
    """Render the "Edit Poster" controls for a poster stored in session state."""
    st.header("Edit Poster")
    choices = [f"Generated Poster {i+1}" for i in _stored_image_indices()]
    image_to_edit = st.selectbox("Select Image to Edit", choices)
    if not image_to_edit:
        return

    image_index = int(image_to_edit.split()[-1]) - 1
    img = Image.open(io.BytesIO(st.session_state[f'image_{image_index}']))
    st.image(img, caption="Current Image", use_column_width=True)
    text_to_remove = st.text_input("Text to Remove")
    new_text = st.text_input("New Text")
    font_size = st.number_input("Font Size", min_value=1, max_value=100, value=40)
    font_color = st.color_picker("Font Color", "#FFFFFF")
    if not st.button("Apply Changes"):
        return

    position = None
    size = None
    if text_to_remove:
        img, position, size = remove_text_from_image(img, text_to_remove)
    if new_text:
        img = add_text_to_image(img, new_text, font_size, font_color, position, size)
    st.image(img, caption="Edited Image", use_column_width=True)
    # Offer the edited image for download as PNG
    st.download_button(
        label="Download Edited Image",
        data=_image_to_png_bytes(img),
        file_name="edited_poster.png",
        mime="image/png",
    )


def main():
    """Render the page: title, then the poster generator, then the poster editor."""
    st.title("Poster Generator and Editor")
    _generation_section()
    _editing_section()
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()