import base64
import io
import math
import random
import urllib.request

import fal_client
import gradio as gr
import numpy as np
import plotly.graph_objects as go
from PIL import Image

# ── Pose constants (fal Multiple-Angles LoRA) ─────────────────────────────────
AZIMUTHS = [0, 45, 90, 135, 180, 225, 270, 315]
ELEVATIONS = [-30, 0, 30, 60]
DISTANCES = [0.6, 1.0, 1.8]

AZIMUTH_NAMES = {
    0: "front view",
    45: "front-right quarter view",
    90: "right side view",
    135: "back-right quarter view",
    180: "back view",
    225: "back-left quarter view",
    270: "left side view",
    315: "front-left quarter view",
}
ELEVATION_NAMES = {-30: "low-angle shot", 0: "eye-level shot", 30: "elevated shot", 60: "high-angle shot"}
DISTANCE_NAMES = {0.6: "close-up", 1.0: "medium shot", 1.8: "wide shot"}

# Trigger token placed in front of every pose description for the
# Multiple-Angles LoRA (prompt format per the LoRA's model card).
LORA_TRIGGER = "<sks>"


# ── Helpers ───────────────────────────────────────────────────────────────────
def build_prompt(az, el, di):
    """Return the LoRA prompt describing one camera pose.

    Args:
        az: azimuth in degrees; must be a key of AZIMUTH_NAMES.
        el: elevation in degrees; must be a key of ELEVATION_NAMES.
        di: distance factor; must be a key of DISTANCE_NAMES.

    Returns:
        The trigger token followed by the comma-separated pose description,
        e.g. "<sks> front view, eye-level shot, medium shot".

    Raises:
        KeyError: if any value is not one of the supported poses.
    """
    return f"{LORA_TRIGGER} {AZIMUTH_NAMES[az]}, {ELEVATION_NAMES[el]}, {DISTANCE_NAMES[di]}"


def image_to_uri(img):
    """Encode a PIL image as a base64 PNG data URI.

    Returns an empty string when ``img`` is None.
    """
    if img is None:
        return ""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/png;base64,{b64}"


# ── 2D Plotly viewer ──────────────────────────────────────────────────────────
#
# Design: top-down view with the subject at the centre (blue square).
# 4 concentric circles = 4 elevation levels; 8 points per circle = 8 azimuths,
# i.e. 32 poses. The plot is display-only (no event handler is attached to it);
# poses are selected with the azimuth/elevation/distance buttons.
# Index in the poses trace: i = az_idx * 4 + el_idx.
ELEVATION_COLORS = {-30: "#4a3a6a", 0: "#00ff88", 30: "#ffa500", 60: "#ff69b4"}
ELEVATION_LABELS = {-30: "Low", 0: "Eye", 30: "Elev", 60: "High"}
# Radius (plot units) of the ring drawn for each elevation in the top-down view.
RADIUS_BY_EL = {-30: 1.2, 0: 2.0, 30: 2.7, 60: 3.3}


def build_viewer(current_az, current_el, current_di):
    """Build the top-down Plotly figure highlighting the currently selected pose.

    Args:
        current_az: selected azimuth in degrees (one of AZIMUTHS).
        current_el: selected elevation in degrees (one of ELEVATIONS).
        current_di: selected distance; only used for the prompt caption.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    # ── Guide circles (one per elevation) ─────────────────────────────────────
    th = np.linspace(0, 2 * np.pi, 100)
    for el, r in RADIUS_BY_EL.items():
        fig.add_trace(go.Scatter(
            x=r * np.sin(th), y=r * np.cos(th), mode="lines",
            line=dict(color=ELEVATION_COLORS[el], width=1, dash="dot"),
            opacity=0.35, hoverinfo="skip", showlegend=False,
        ))

    # ── Subject at the centre ─────────────────────────────────────────────────
    fig.add_trace(go.Scatter(
        x=[0], y=[0], mode="markers+text",
        marker=dict(size=38, color="#3a4a8a", symbol="square",
                    line=dict(color="#8aa8ff", width=2)),
        text=["📷"], textfont=dict(size=20, color="white"),
        hoverinfo="skip", showlegend=False,
    ))

    # ── Cardinal labels ───────────────────────────────────────────────────────
    fig.add_trace(go.Scatter(
        x=[0, 3.7, 0, -3.7], y=[3.7, 0, -3.7, 0], mode="text",
        text=["FRONT 0°", "RIGHT 90°", "BACK 180°", "LEFT 270°"],
        textfont=dict(size=13, color="#cccccc"),
        hoverinfo="skip", showlegend=False,
    ))

    # ── 32 pose markers (8 az × 4 el) ─────────────────────────────────────────
    xs, ys, colors, sizes, hovers = [], [], [], [], []
    for az in AZIMUTHS:
        for el in ELEVATIONS:
            r = RADIUS_BY_EL[el]
            xs.append(r * math.sin(math.radians(az)))
            ys.append(r * math.cos(math.radians(az)))
            is_current = az == current_az and el == current_el
            colors.append("#ffffff" if is_current else ELEVATION_COLORS[el])
            sizes.append(28 if is_current else 16)
            # Plotly text breaks lines on <br>, not on newline characters.
            hovers.append(f"{AZIMUTH_NAMES[az]}<br>{ELEVATION_NAMES[el]}")

    fig.add_trace(go.Scatter(
        x=xs, y=ys, mode="markers",
        marker=dict(size=sizes, color=colors, line=dict(color="white", width=1)),
        hovertext=hovers, hoverinfo="text", showlegend=False, name="poses",
    ))

    # ── Elevation legend (top-right corner) ───────────────────────────────────
    legend_y = 4.0
    for i, (el, col) in enumerate(ELEVATION_COLORS.items()):
        fig.add_trace(go.Scatter(
            x=[3.0], y=[legend_y - i * 0.4], mode="markers+text",
            marker=dict(size=12, color=col),
            text=[f" {ELEVATION_LABELS[el]} ({el}°)"], textposition="middle right",
            textfont=dict(size=11, color="#cccccc"),
            hoverinfo="skip", showlegend=False,
        ))

    prompt = build_prompt(current_az, current_el, current_di)
    fig.update_layout(
        paper_bgcolor="#1a1a1a", plot_bgcolor="#1a1a1a",
        margin=dict(l=10, r=10, t=10, b=50),
        xaxis=dict(visible=False, range=[-4.5, 4.8], scaleanchor="y", scaleratio=1),
        yaxis=dict(visible=False, range=[-4.5, 4.5]),
        height=460,
        annotations=[dict(
            text=prompt, showarrow=False, xref="paper", yref="paper",
            x=0.5, y=-0.02,
            font=dict(family="monospace", size=13, color="#00ff88"),
            bgcolor="rgba(0,0,0,0.85)", borderpad=8,
        )],
    )
    return fig


# ── fal.ai inference ──────────────────────────────────────────────────────────
# Seconds to wait when downloading the generated image from fal's result URL,
# so a stalled download cannot block the worker forever.
RESULT_DOWNLOAD_TIMEOUT = 60


def infer(image, az, el, di, seed, randomize_seed):
    """Re-render the source image from the requested camera pose via fal.ai.

    Args:
        image: source PIL image; None raises gr.Error.
        az, el, di: pose values accepted by build_prompt.
        seed: seed used when ``randomize_seed`` is false. An empty Seed field
            (None) falls back to a random seed.
        randomize_seed: if true, draw a fresh random seed instead of ``seed``.

    Returns:
        Tuple of (output PIL image converted to RGB, seed actually used, prompt sent).

    Raises:
        gr.Error: if no source image was given or fal returned no image.
    """
    if image is None:
        raise gr.Error("Upload a source image first")
    # gr.Number can yield None when the field is cleared; never send that to fal.
    if randomize_seed or seed is None:
        seed = random.randint(0, 2**31 - 1)
    seed = int(seed)
    prompt = build_prompt(az, el, di)

    result = fal_client.run(
        "fal-ai/qwen-image-edit",
        arguments={
            "image_url": image_to_uri(image),
            "prompt": prompt,
            "seed": seed,
            "image_size": {"width": 1024, "height": 1024},
            "num_inference_steps": 4,
            "guidance_scale": 1.0,
            "loras": [{"path": "fal/Qwen-Image-Edit-2511-Multiple-Angles-LoRA", "scale": 1.0}],
        },
    )

    images = result.get("images") or []
    if not images:
        raise gr.Error("fal.ai returned no image")
    with urllib.request.urlopen(images[0]["url"], timeout=RESULT_DOWNLOAD_TIMEOUT) as resp:
        out_img = Image.open(io.BytesIO(resp.read())).convert("RGB")
    return out_img, seed, prompt


# ── Gradio UI ─────────────────────────────────────────────────────────────────
AZ_LABELS = {0: "⬆ Front", 45: "↗ Front-R", 90: "➡ Right", 135: "↘ Back-R",
             180: "⬇ Back", 225: "↙ Back-L", 270: "⬅ Left", 315: "↖ Front-L"}
EL_LABELS = {-30: "⬇ Low (-30°)", 0: "➡ Eye (0°)", 30: "⬈ Elev (+30°)", 60: "⬆ High (+60°)"}
DI_LABELS = {0.6: "🔍 Close-up", 1.0: "📷 Medium", 1.8: "🌄 Wide"}

# Pose selected when a session starts.
DEFAULT_AZ, DEFAULT_EL, DEFAULT_DI = 0, 0, 1.0


def _variant_updates(options, selected):
    """Return one button update per option, marking only ``selected`` as primary."""
    return [gr.Button(variant="primary" if o == selected else "secondary") for o in options]


with gr.Blocks(title="Angle Studio") as demo:
    gr.Markdown(
        "# 🎥 Angle Studio\n"
        "**Cliquez 1 azimut, 1 élévation, 1 distance — puis générez.**\n"
        "*Pick one azimuth, one elevation, one distance — then generate.*"
    )

    az_state = gr.State(DEFAULT_AZ)
    el_state = gr.State(DEFAULT_EL)
    di_state = gr.State(DEFAULT_DI)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Image source / Source image", type="pil")

            gr.Markdown("### 🧭 Azimut")
            with gr.Row():
                az_buttons = [
                    gr.Button(AZ_LABELS[a], size="sm",
                              variant=("primary" if a == DEFAULT_AZ else "secondary"))
                    for a in AZIMUTHS
                ]

            gr.Markdown("### 📐 Élévation")
            with gr.Row():
                el_buttons = [
                    gr.Button(EL_LABELS[e], size="sm",
                              variant=("primary" if e == DEFAULT_EL else "secondary"))
                    for e in ELEVATIONS
                ]

            gr.Markdown("### 🔭 Distance")
            with gr.Row():
                di_buttons = [
                    gr.Button(DI_LABELS[d], size="sm",
                              variant=("primary" if d == DEFAULT_DI else "secondary"))
                    for d in DISTANCES
                ]

            prompt_preview = gr.Textbox(
                label="Prompt", interactive=False,
                value=build_prompt(DEFAULT_AZ, DEFAULT_EL, DEFAULT_DI),
            )
            with gr.Row():
                seed_input = gr.Number(label="Seed", value=0, precision=0)
                randomize = gr.Checkbox(label="Random seed", value=True)
            generate_btn = gr.Button("▶ Générer / Generate", variant="primary", size="lg")

        with gr.Column(scale=1):
            viewer = gr.Plot(value=build_viewer(DEFAULT_AZ, DEFAULT_EL, DEFAULT_DI), show_label=False)
            output_image = gr.Image(label="Résultat / Result", type="pil")
            output_seed = gr.Number(label="Seed utilisé", interactive=False)

    gr.Markdown("### 🖼️ Galerie de session / Session Gallery")
    gallery = gr.Gallery(columns=4, height=260)
    session_images = gr.State([])

    # ── Handlers: each button updates its state, the viewer and the prompt,
    # and re-highlights its own button group so the selection stays visible.
    def make_az_handler(a):
        def _fn(el, di):
            return (a, build_viewer(a, el, di), build_prompt(a, el, di),
                    *_variant_updates(AZIMUTHS, a))
        return _fn

    def make_el_handler(e):
        def _fn(az, di):
            return (e, build_viewer(az, e, di), build_prompt(az, e, di),
                    *_variant_updates(ELEVATIONS, e))
        return _fn

    def make_di_handler(d):
        def _fn(az, el):
            return (d, build_viewer(az, el, d), build_prompt(az, el, d),
                    *_variant_updates(DISTANCES, d))
        return _fn

    for a, btn in zip(AZIMUTHS, az_buttons):
        btn.click(fn=make_az_handler(a), inputs=[el_state, di_state],
                  outputs=[az_state, viewer, prompt_preview, *az_buttons])
    for e, btn in zip(ELEVATIONS, el_buttons):
        btn.click(fn=make_el_handler(e), inputs=[az_state, di_state],
                  outputs=[el_state, viewer, prompt_preview, *el_buttons])
    for d, btn in zip(DISTANCES, di_buttons):
        btn.click(fn=make_di_handler(d), inputs=[az_state, el_state],
                  outputs=[di_state, viewer, prompt_preview, *di_buttons])

    def run_and_append(image, az, el, di, seed, rand, history):
        """Run inference and append the result to this session's gallery history."""
        result, used_seed, _ = infer(image, az, el, di, seed, rand)
        history = history + [result]
        return result, used_seed, history, history

    generate_btn.click(
        fn=run_and_append,
        inputs=[input_image, az_state, el_state, di_state, seed_input, randomize, session_images],
        outputs=[output_image, output_seed, session_images, gallery],
    )

demo.launch(theme=gr.themes.Base())