VetriVendhan26 commited on
Commit
7266287
·
verified ·
1 Parent(s): 92434f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -85
app.py CHANGED
@@ -1,95 +1,46 @@
1
- import os
2
- import sys
3
  import gradio as gr
4
- import subprocess
5
  import torch
6
- from pathlib import Path
7
-
8
- # Add MagicDrive as submodule (auto-clones on first run)
9
- MAGICDRIVE_DIR = Path("./MagicDrive")
10
- if not MAGICDRIVE_DIR.exists():
11
- subprocess.run(["git", "clone", "--recursive", "https://github.com/cure-lab/MagicDrive.git"])
12
- sys.path.append(str(MAGICDRIVE_DIR))
13
- os.chdir(MAGICDRIVE_DIR)
14
-
15
- # Install deps on first run
16
- subprocess.run(["pip", "install", "-r", "requirements/gui.txt"], capture_output=True)
17
- subprocess.run(["pip", "install", "gradio", "omegaconf", "hydra-core"], capture_output=True)
18
-
19
- # Import MagicDrive modules
20
- sys.path.insert(0, str(MAGICDRIVE_DIR))
21
- from demo.interactive_gui import load_model_from, run_pipe
22
- import glob
23
-
24
- # Global model cache
25
- model_cache = None
26
-
27
- def load_magicdrive_model():
28
- global model_cache
29
- if model_cache is None:
30
- # Download minimal pretrained if needed
31
- os.makedirs("pretrained/stable-diffusion-v1-5", exist_ok=True)
32
- model_dir = "pretrained/SDv1.5mv-rawbox_2023-09-07_18-39_224x400"
33
- os.makedirs(model_dir, exist_ok=True)
34
-
35
- model_cache = load_model_from(model_dir, device="cuda" if torch.cuda.is_available() else "cpu")
36
- return model_cache
37
-
38
- def generate_image(image_input, prompt, seed=42):
39
- if image_input is None:
40
- return None, "Upload an image first"
41
-
42
- model = load_magicdrive_model()
43
- cfg, pipe = model
44
-
45
- # Run MagicDrive inference
46
- import numpy as np
47
- from PIL import Image
48
- import random
49
 
50
- random.seed(seed)
51
  torch.manual_seed(seed)
 
 
 
 
 
 
 
52
 
53
- # Convert Gradio image to numpy
54
- img_array = np.array(image_input)
55
-
56
- result = run_pipe(cfg, pipe, img_array, seed)
57
-
58
- # Convert back to PIL
59
- result_img = Image.fromarray((result * 255).astype(np.uint8))
60
- return result_img, f"Generated with seed: {seed}"
61
 
62
- # Gradio Tabs Interface
63
- with gr.Blocks(title="🪄 MagicDrive Demo") as demo:
64
- gr.Markdown("# 🪄 MagicDrive: Street View Generation")
65
- gr.Markdown("Upload image → Generate street view with 3D geometry control")
66
 
67
- with gr.Tabs():
68
- with gr.TabItem("🚗 Street View Generator"):
69
- with gr.Row():
70
- with gr.Column(scale=1):
71
- input_image = gr.Image(type="pil", label="Input Image")
72
- prompt = gr.Textbox("Street view, realistic, detailed", label="Prompt")
73
- seed = gr.Slider(0, 10000, value=42, label="Seed")
74
- generate_btn = gr.Button("Generate Street View", variant="primary")
75
-
76
- with gr.Column(scale=1):
77
- output_image = gr.Image(label="Generated Street View")
78
- status = gr.Textbox(label="Status", interactive=False)
79
-
80
- generate_btn.click(
81
- fn=generate_image,
82
- inputs=[input_image, prompt, seed],
83
- outputs=[output_image, status]
84
- )
85
 
86
- with gr.TabItem("📊 Model Info"):
87
- gr.Markdown("""
88
- ## MagicDrive Features
89
- - Street view synthesis with 3D geometry control
90
- - Multi-view consistency
91
- - Real-time interactive GUI
92
- - Based on: [MagicDrive Paper](https://arxiv.org/abs/2310.10687)
93
- """)
94
 
95
  demo.launch()
 
 
 
1
  import gradio as gr
 
2
  import torch
3
+ from diffusers import StableDiffusionPipeline
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ # Lightweight SD pipeline (no MagicDrive complexity)
8
+ pipe = StableDiffusionPipeline.from_pretrained(
9
+ "runwayml/stable-diffusion-v1-5",
10
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
11
+ safety_checker=None
12
+ )
13
+ pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
14
+
15
+ def generate_street_view(image, prompt, seed=42):
16
+ if image is None:
17
+ return None, "Upload input image first"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # Simulate MagicDrive: SD img2img
20
  torch.manual_seed(seed)
21
+ result = pipe(
22
+ prompt,
23
+ image=image,
24
+ strength=0.75,
25
+ num_inference_steps=20,
26
+ guidance_scale=7.5
27
+ ).images[0]
28
 
29
+ return result, f"Street view generated! Seed: {seed}"
 
 
 
 
 
 
 
30
 
31
+ with gr.Blocks(title="MagicDrive Demo") as demo:
32
+ gr.Markdown("# 🪄 MagicDrive-Style Street View Generator")
 
 
33
 
34
+ with gr.Row():
35
+ with gr.Column():
36
+ input_img = gr.Image(type="pil", label="Input Street Image")
37
+ prompt = gr.Textbox("realistic street view, detailed cars, buildings", label="Prompt")
38
+ seed = gr.Slider(0, 10000, value=42, label="Seed")
39
+ btn = gr.Button("Generate", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ output_img = gr.Image(label="Output Street View")
42
+ status = gr.Textbox(label="Status")
43
+
44
+ btn.click(generate_street_view, [input_img, prompt, seed], [output_img, status])
 
 
 
 
45
 
46
  demo.launch()