# top_model / app.py — Hugging Face Space (commit 9d515e8, "Update app.py")
# import gradio as gr
# import subprocess
# import os
# import random
# from PIL import Image
# import shutil
# import requests
# # === Setup Paths ===
# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# GEN_SCRIPT = os.path.join(BASE_DIR, "stylegan3", "gen_images.py")
# OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
# MODEL_PATH = os.path.join(BASE_DIR, "top_model.pkl")
# SAVE_DIR = os.path.join(BASE_DIR, "saved_images")
# os.makedirs(OUTPUT_DIR, exist_ok=True)
# os.makedirs(SAVE_DIR, exist_ok=True)
# # === Image Generation Function ===
# def generate_images():
# command = [
# "python",
# GEN_SCRIPT,
# f"--outdir={OUTPUT_DIR}",
# "--trunc=1",
# "--seeds=3-5,7,9,12-14,16-26,29,31,32,34,40,41",
# f"--network={MODEL_PATH}"
# ]
# try:
# subprocess.run(command, check=True, capture_output=True, text=True)
# except subprocess.CalledProcessError as e:
# return f"Error generating images:\n{e.stderr}"
# # === Select Random Images from Output Folder ===
# def get_random_images():
# image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")]
# if len(image_files) < 10:
# generate_images()
# image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")]
# random_images = random.sample(image_files, min(10, len(image_files)))
# image_paths = [os.path.join(OUTPUT_DIR, img) for img in random_images]
# return image_paths
# # === Send Image to Backend ===
# def send_to_backend(img_path, user_id):
# if not user_id:
# return "❌ user_id not found in URL."
# if not img_path or not os.path.exists(img_path):
# return "⚠️ No image selected or image not found."
# try:
# with open(img_path, 'rb') as f:
# files = {'file': ('generated_image.png', f, 'image/png')}
# # Your backend endpoint here
# url = f" https://7da2-2409-4042-6e81-1806-de6-b8e5-836c-6b95.ngrok-free.app/images/upload/{user_id}"
# response = requests.post(url, files=files)
# if response.status_code == 201:
# return "✅ Image uploaded and saved to database!"
# else:
# return f"❌ Upload failed: {response.status_code} - {response.text}"
# except Exception as e:
# return f"⚠️ Error: {str(e)}"
# # === Gradio Interface ===
# with gr.Blocks() as demo:
# gr.Markdown("# 🎨 AI-Generated Clothing Designs - Tops")
# generate_button = gr.Button("Generate New Designs")
# user_id_state = gr.State()
# @demo.load(inputs=None, outputs=[user_id_state])
# def get_user_id(request: gr.Request):
# return request.query_params.get("user_id", "")
# image_components = []
# file_paths = []
# save_buttons = []
# outputs = []
# # Use 3 columns layout
# for row_idx in range(4): # 4 rows (to cover 10 images)
# with gr.Row():
# for col_idx in range(3): # 3 columns
# i = row_idx * 3 + col_idx
# if i >= 10:
# break
# with gr.Column():
# img = gr.Image(width=180, height=180, label=f"Design {i+1}")
# image_components.append(img)
# file_path = gr.Textbox(visible=False)
# file_paths.append(file_path)
# save_btn = gr.Button("💾 Save to DB")
# save_buttons.append(save_btn)
# output = gr.Textbox(label="Status", interactive=False)
# outputs.append(output)
# save_btn.click(
# fn=send_to_backend,
# inputs=[file_path, user_id_state],
# outputs=output
# )
# # Generate button logic
# def generate_and_display_images():
# image_paths = get_random_images()
# return image_paths + image_paths # One for display, one for hidden path tracking
# generate_button.click(
# fn=generate_and_display_images,
# outputs=image_components + file_paths
# )
# if __name__ == "__main__":
# demo.launch()
import torch
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import numpy as np
import pickle
import gradio as gr
import tempfile
# Force CPU usage for optimization
# NOTE(review): `device` is defined but never used below — the models are
# moved to CPU explicitly with .cpu() instead.
device = torch.device("cpu")
# Load your GAN model (StyleGAN3-style checkpoint; the 'G_ema' entry is the
# exponential-moving-average generator used for inference).
# NOTE(review): pickle.load executes arbitrary code from the checkpoint file;
# only load model files from a trusted source.
with open("top_model.pkl", "rb") as f:
    G = pickle.load(f)['G_ema'].eval().cpu()  # Ensure model is in eval mode and on CPU
# Load CLIP model and processor (used to rank generated images against the
# user's text prompt).
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval().cpu()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# def send_to_backend(img_path, user_id):
# if not user_id:
# return "❌ user_id not found in URL."
# if not img_path or not os.path.exists(img_path):
# return "⚠️ No image selected or image not found."
# try:
# with open(img_path, 'rb') as f:
# files = {'file': ('generated_image.png', f, 'image/png')}
# # Your backend endpoint here
# url = f"https://e335-103-40-74-83.ngrok-free.app/images/upload/{user_id}"
# response = requests.post(url, files=files)
# if response.status_code == 201:
# return "✅ Image uploaded and saved to database!"
# else:
# console.log({response.text})
# return f"❌ Upload failed: {response.status_code} - {response.text}"
# except Exception as e:
# return f"⚠️ Error: {str(e)}"
import os
import requests # Make sure you import this!
def send_to_backend(img_path, user_id):
    """Upload a generated image to the backend image service.

    Args:
        img_path: Filesystem path of the PNG to upload.
        user_id: Backend user id extracted from the page URL query string.

    Returns:
        A human-readable status string describing success or failure
        (shown in the per-image "Status" textbox in the UI).
    """
    print(f"💡 [DEBUG] Sending image to backend | img_path={img_path}, user_id={user_id}")
    if not user_id:
        print("❌ [DEBUG] Missing user_id in URL.")
        return "❌ user_id not found in URL."
    if not img_path or not os.path.exists(img_path):
        print("⚠️ [DEBUG] Image path invalid or does not exist.")
        return "⚠️ No image selected or image not found."
    try:
        with open(img_path, 'rb') as f:
            files = {'file': ('generated_image.png', f, 'image/png')}
            # BUG FIX: the URL previously began with a stray space
            # (f" https://…"), which makes requests reject it on every
            # upload attempt.
            url = f"https://68be601de1e4.ngrok-free.app/images/upload/{user_id}"
            print(f"🔁 [DEBUG] Sending POST to {url}")
            # Timeout keeps the Gradio worker from hanging forever if the
            # ngrok tunnel is down.
            response = requests.post(url, files=files, timeout=30)
        print(f"📩 [DEBUG] Response: {response.status_code} - {response.text}")
        if response.status_code in (200, 201):
            return "✅ Image uploaded and saved to database!"
        return f"❌ Upload failed: {response.status_code} - {response.text}"
    except Exception as e:
        print(f"⚠️ [ERROR] Exception during upload: {str(e)}")
        return f"⚠️ Error: {str(e)}"
# Generate images
def generate_images(G, num_images=10):  # Reduce for CPU performance
    """Sample a batch of latent vectors and decode them with generator G.

    Args:
        G: GAN generator exposing `z_dim` and callable as G(z, c).
        num_images: Number of images to sample.

    Returns:
        (z, images): the latent batch, and a uint8 array of shape
        (num_images, H, W, 3) with pixel values in [0, 255].
    """
    latents = torch.randn(num_images, G.z_dim)
    labels = None  # unconditional generation: no class labels
    with torch.no_grad():
        raw = G(latents, labels)
    # Map the model's [-1, 1] output into [0, 255] and go NCHW -> NHWC.
    scaled = (raw.clamp(-1, 1) + 1) * (255 / 2)
    pixels = scaled.permute(0, 2, 3, 1).numpy().astype(np.uint8)
    return latents, pixels
# Rank images using CLIP
def rank_by_clip(images, prompt, top_k=3):  # Reduce top_k for speed
    """Rank generated images by CLIP similarity to a text prompt.

    Args:
        images: iterable of HWC uint8 numpy arrays.
        prompt: text description to score the images against.
        top_k: number of best-matching images to return.

    Returns:
        List of the `top_k` highest-scoring images as PIL Images,
        best match first.
    """
    images_pil = [Image.fromarray(img) for img in images]
    inputs = clip_processor(text=[prompt], images=images_pil, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_features = clip_model.get_image_features(pixel_values=inputs["pixel_values"])
        text_features = clip_model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    # Cosine similarity: L2-normalize both embedding sets first.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # (N, 1) -> (N,). BUG FIX: squeeze(-1) instead of squeeze() so that a
    # single-image batch is not collapsed to a 0-dim tensor, which would
    # break the argsort slicing below.
    similarity = (image_features @ text_features.T).squeeze(-1)
    top_indices = similarity.argsort(descending=True)[:top_k]
    best_images = [images_pil[i] for i in top_indices]
    return best_images
# Gradio interface function
def generate_top_dresses(prompt):
    """Generate candidate tops and return the two best CLIP matches.

    Args:
        prompt: text description of the desired design.

    Returns:
        (top_images, file_paths): PIL images for display, and the temp-file
        paths of the saved PNGs (consumed later by the save-to-DB buttons).
    """
    _, images = generate_images(G, num_images=20)
    top_images = rank_by_clip(images, prompt, top_k=2)
    file_paths = []
    for img in top_images:
        # BUG FIX: close the temp-file handle before writing. The original
        # used NamedTemporaryFile(...).name without closing, leaking one
        # open file descriptor per generated image.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            temp_path = tmp.name
        img.save(temp_path)
        file_paths.append(temp_path)
    return top_images, file_paths
# Launch Gradio
# Launch Gradio
# Builds the UI: one prompt box, a generate button, and two image slots each
# paired with a hidden path textbox, a save button, and a status textbox.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("""
    # 👗 AI Top Generator
    _Type in your dream outfit, and let the AI bring your fashion vision to life!_
    Just describe and see how AI transforms your words into fashion.
    """)
    with gr.Row():
        input_box = gr.Textbox(
            label="Describe your Design",
            placeholder="e.g., 'Black sleeveless crop top'",
            lines=2
        )
    with gr.Row():
        submit_button = gr.Button("Generate Designs")
    # Holds the user_id read from the page URL for this session.
    user_id_state = gr.State()
    # Runs once on page load to capture ?user_id=... from the request URL.
    @demo.load(inputs=None, outputs=[user_id_state])
    def get_user_id(request: gr.Request):
        """Return the `user_id` query parameter, or "" if absent."""
        return request.query_params.get("user_id", "")
    image_components = []
    file_paths = []
    save_buttons = []
    outputs = []
    with gr.Row():
        for i in range(2):  # Only 2 images
            with gr.Column():
                img = gr.Image(width=180, height=180, label=f"Design {i+1}")
                image_components.append(img)
                # Hidden textbox mirrors the image's temp-file path so the
                # save button can send the file (not pixels) to the backend.
                file_path = gr.Textbox(visible=False)
                file_paths.append(file_path)
                save_btn = gr.Button("💾 Save to DB")
                save_buttons.append(save_btn)
                output = gr.Textbox(label="Status", interactive=False)
                outputs.append(output)
                save_btn.click(
                    fn=send_to_backend,
                    inputs=[file_path, user_id_state],
                    outputs=output
                )
    examples = gr.Examples(
        examples = [
            ["Simple red V-neck top"],
            ["Simple blue round-neck top with short sleeves"]
        ],
        inputs=[input_box]
    )
    # Generate button logic
    def generate_and_display_images(prompt):
        """Return images followed by their paths, matching the flat
        `image_components + file_paths` output list order below."""
        images, paths = generate_top_dresses(prompt)
        return images + paths
    submit_button.click(
        fn=generate_and_display_images,
        inputs=[input_box],
        outputs=image_components + file_paths
    )
demo.launch()