UDface11jkj commited on
Commit
2154210
Β·
verified Β·
1 Parent(s): 12a9ca6

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +18 -25
src/streamlit_app.py CHANGED
@@ -4,11 +4,12 @@ import torch
4
  import os
5
  import time
6
  import tempfile
 
7
  from huggingface_hub import snapshot_download
8
 
 
9
  class ImageGenerator:
10
  def __init__(self, ae_path, dit_path, qwen2vl_model_path, max_length=640):
11
- # Initialize the model with the provided paths
12
  self.ae_path = ae_path
13
  self.dit_path = dit_path
14
  self.qwen2vl_model_path = qwen2vl_model_path
@@ -17,48 +18,43 @@ class ImageGenerator:
17
  self.load_model()
18
 
19
  def load_model(self):
20
- # Load model weights or any necessary model setup here
21
  pass
22
-
23
  def to_cuda(self):
24
- # Move model to GPU if available
25
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
- # Example: Loading your model (use actual code to load)
27
  self.model = torch.load(self.ae_path, map_location=self.device)
28
- # Additional model loading logic for your specific case
29
 
 
30
  def inference(prompt, image, seed, size_level, model):
31
- # Add model prediction logic here
32
- # Example: Pass image and prompt to the model to generate output
33
- # Modify according to your actual model's inference code
34
- result_image = image # Placeholder, replace with actual generation logic
35
- used_seed = seed if seed != -1 else int(time.time()) # Use random seed if -1
36
  return result_image, used_seed
37
 
38
- # Set page config for better UI layout
39
  st.set_page_config(page_title="Ghibli style", layout="centered")
40
  st.title("πŸ–ΌοΈ Ghibli style for Free : AI Image Editing")
41
- st.markdown("Ghibli style images with AI.")
42
 
43
- # === User Inputs ===
44
  prompt = "Turn into an illustration in Studio Ghibli style"
45
  uploaded_image = st.file_uploader("πŸ“€ Upload an Image", type=["jpg", "jpeg", "png"])
46
  seed = st.number_input("🎲 Random Seed (-1 for random)", value=-1, step=1)
47
  size_level = st.number_input("πŸ“ Size Level (minimum 512)", value=512, min_value=512, step=32)
48
-
49
  generate_button = st.button("πŸš€ Generate")
50
 
51
- # === Load Model (Cached) ===
52
  @st.cache_resource
53
  def load_model():
54
  repo = "stepfun-ai/Step1X-Edit"
55
- local_dir = "./step1x_weights"
56
- os.makedirs(local_dir, exist_ok=True)
 
57
  snapshot_download(repo_id=repo, local_dir=local_dir, local_dir_use_symlinks=False)
58
 
59
  model = ImageGenerator(
60
- ae_path=os.path.join(local_dir, 'vae.safetensors'),
61
- dit_path=os.path.join(local_dir, "step1x-edit-i1258.safetensors"),
62
  qwen2vl_model_path='Qwen/Qwen2.5-VL-7B-Instruct',
63
  max_length=640
64
  )
@@ -66,12 +62,11 @@ def load_model():
66
 
67
  image_edit_model = load_model()
68
 
69
- # === Inference and Image Display ===
70
  if generate_button and uploaded_image is not None:
71
  input_image = Image.open(uploaded_image).convert("RGB")
72
- # Resize image for faster inference (adjust to your model's requirements)
73
  input_image.thumbnail((size_level, size_level))
74
-
75
  with st.spinner("πŸ”„ Generating edited image..."):
76
  start = time.time()
77
  try:
@@ -79,8 +74,6 @@ if generate_button and uploaded_image is not None:
79
  end = time.time()
80
 
81
  st.success(f"βœ… Done in {end - start:.2f} seconds β€” Seed used: {used_seed}")
82
-
83
- # Save and display the result in temporary file
84
  with tempfile.NamedTemporaryFile(dir="/tmp", delete=False, suffix=".png") as temp_file:
85
  result_image.save(temp_file.name)
86
  st.image(temp_file.name, caption="πŸ–ΌοΈ Edited Image", use_column_width=True)
 
4
  import os
5
  import time
6
  import tempfile
7
+ from pathlib import Path
8
  from huggingface_hub import snapshot_download
9
 
10
+ # === Model Wrapper Class ===
11
  class ImageGenerator:
12
  def __init__(self, ae_path, dit_path, qwen2vl_model_path, max_length=640):
 
13
  self.ae_path = ae_path
14
  self.dit_path = dit_path
15
  self.qwen2vl_model_path = qwen2vl_model_path
 
18
  self.load_model()
19
 
20
  def load_model(self):
21
+ # Dummy placeholder - replace with actual model loading logic
22
  pass
23
+
24
  def to_cuda(self):
 
25
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
26
  self.model = torch.load(self.ae_path, map_location=self.device)
27
+ # Add actual model load logic as needed
28
 
29
+ # === Inference Function ===
30
  def inference(prompt, image, seed, size_level, model):
31
+ result_image = image # Placeholder - Replace with actual inference logic
32
+ used_seed = seed if seed != -1 else int(time.time())
 
 
 
33
  return result_image, used_seed
34
 
35
+ # === Streamlit UI Setup ===
36
  st.set_page_config(page_title="Ghibli style", layout="centered")
37
  st.title("πŸ–ΌοΈ Ghibli style for Free : AI Image Editing")
38
+ st.markdown("Generate Studio Ghibli style illustrations from your image using AI.")
39
 
 
40
  prompt = "Turn into an illustration in Studio Ghibli style"
41
  uploaded_image = st.file_uploader("πŸ“€ Upload an Image", type=["jpg", "jpeg", "png"])
42
  seed = st.number_input("🎲 Random Seed (-1 for random)", value=-1, step=1)
43
  size_level = st.number_input("πŸ“ Size Level (minimum 512)", value=512, min_value=512, step=32)
 
44
  generate_button = st.button("πŸš€ Generate")
45
 
46
+ # === Cached Model Loader ===
47
  @st.cache_resource
48
  def load_model():
49
  repo = "stepfun-ai/Step1X-Edit"
50
+ local_dir = Path.home() / "step1x_weights"
51
+ local_dir.mkdir(exist_ok=True)
52
+
53
  snapshot_download(repo_id=repo, local_dir=local_dir, local_dir_use_symlinks=False)
54
 
55
  model = ImageGenerator(
56
+ ae_path=local_dir / 'vae.safetensors',
57
+ dit_path=local_dir / "step1x-edit-i1258.safetensors",
58
  qwen2vl_model_path='Qwen/Qwen2.5-VL-7B-Instruct',
59
  max_length=640
60
  )
 
62
 
63
  image_edit_model = load_model()
64
 
65
+ # === Handle Generation ===
66
  if generate_button and uploaded_image is not None:
67
  input_image = Image.open(uploaded_image).convert("RGB")
 
68
  input_image.thumbnail((size_level, size_level))
69
+
70
  with st.spinner("πŸ”„ Generating edited image..."):
71
  start = time.time()
72
  try:
 
74
  end = time.time()
75
 
76
  st.success(f"βœ… Done in {end - start:.2f} seconds β€” Seed used: {used_seed}")
 
 
77
  with tempfile.NamedTemporaryFile(dir="/tmp", delete=False, suffix=".png") as temp_file:
78
  result_image.save(temp_file.name)
79
  st.image(temp_file.name, caption="πŸ–ΌοΈ Edited Image", use_column_width=True)