CyberRohith commited on
Commit
beeaba3
·
verified ·
1 Parent(s): 1040cdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -17
app.py CHANGED
@@ -1,32 +1,79 @@
1
- import gradio as gr
2
- import torch
3
- import time
4
 
 
5
  if not hasattr(torch, 'float8_e8m0fnu'):
6
- torch.float8_e8m0fnu = torch.float16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from LLM_pipeline import smart_generate
9
  from model_loading import GenerationSession
10
 
11
  model_id = "simianluo/lcm_dreamshaper_v7"
12
  session = GenerationSession(model_id)
13
 
14
- if session.device == "cpu":
15
- print("Applying High-Quality CPU sequential offload patches...")
16
- try:
17
-
18
- session.txt2img_pipeline.enable_sequential_cpu_offload()
19
- session.img2img_pipeline.enable_sequential_cpu_offload()
20
- except Exception as e:
21
- print(f"Note: Multi-pipeline offload optimized via sub-modules: {e}")
22
 
23
  def ui_handler(user_prompt):
24
  start_time = time.time()
25
- image_list, enhanced_text = smart_generate(user_prompt, session, strength=0.45)
26
- final_image = image_list if isinstance(image_list, list) else image_list
27
- end_time = time.time()
28
- print(f"Image generation time: {end_time:.2f}s")
 
 
 
29
 
 
 
 
 
 
30
  return final_image, enhanced_text, f"Total generation time: {end_time - start_time:.2f}s"
31
 
32
  def ui_reset():
@@ -42,7 +89,7 @@ with gr.Blocks(title="Active Image Generator", theme=gr.Theme.from_hub("Respair/
42
  reset_button = gr.Button("Reset Session", variant="secondary")
43
 
44
  with gr.Column():
45
- output_image = gr.Image(label="Generated Image")
46
  enhanced_prompt = gr.Textbox(label="Enhanced Prompt", interactive=False)
47
  time_output = gr.Textbox(label="Generation Time", interactive=False)
48
 
 
1
+ import sys
2
+ from unittest.mock import MagicMock
 
3
 
4
+ import torch
5
  if not hasattr(torch, 'float8_e8m0fnu'):
6
+ torch.float8_e8m0fnu = torch.float16
7
+
8
+ import io
9
+ from PIL import Image
10
+ from huggingface_hub import InferenceClient
11
+
12
+ class APIPipelineMock:
13
+ def __init__(self):
14
+ self.client = InferenceClient()
15
+ self.model_id = "simianluo/lcm_dreamshaper_v7"
16
+ self.scheduler = MagicMock()
17
+
18
+ def to(self, device): return self
19
+ def enable_attention_slicing(self): pass
20
+ def enable_vae_slicing(self): pass
21
+
22
+ def __call__(self, prompt, image=None, strength=0.5, **kwargs):
23
+ if image is None:
24
+ img = self.client.text_to_image(prompt, model=self.model_id)
25
+ else:
26
+ img_byte_arr = io.BytesIO()
27
+ image.save(img_byte_arr, format='PNG')
28
+ img_bytes = img_byte_arr.getvalue()
29
+ img = self.client.image_to_image(img_bytes, prompt=prompt, model=self.model_id, strength=strength)
30
+
31
+ class DiffusersOutput:
32
+ def __init__(self, images): self.images = images
33
+ return DiffusersOutput(images=img)
34
+
35
+ mock_pipe = APIPipelineMock()
36
+
37
+ class MockDiffusionPipeline:
38
+ @classmethod
39
+ def from_pretrained(cls, *args, **kwargs): return mock_pipe
40
 
41
+ class MockAutoPipeline:
42
+ @classmethod
43
+ def from_pipe(cls, *args, **kwargs): return mock_pipe
44
+
45
+
46
+ import diffusers
47
+ diffusers.DiffusionPipeline = MockDiffusionPipeline
48
+ diffusers.AutoPipelineForImage2Image = MockAutoPipeline
49
+
50
+
51
+ import gradio as gr
52
+ import time
53
  from LLM_pipeline import smart_generate
54
  from model_loading import GenerationSession
55
 
56
  model_id = "simianluo/lcm_dreamshaper_v7"
57
  session = GenerationSession(model_id)
58
 
59
+ session.txt2img_pipeline = mock_pipe
60
+ session.img2img_pipeline = mock_pipe
 
 
 
 
 
 
61
 
62
  def ui_handler(user_prompt):
63
  start_time = time.time()
64
+
65
+ if session.current_image is None:
66
+ raw_output = mock_pipe(prompt=user_prompt)
67
+ final_image = raw_output.images
68
+ else:
69
+ raw_output = mock_pipe(prompt=user_prompt, image=session.current_image, strength=0.45)
70
+ final_image = raw_output.images
71
 
72
+ session.current_image = final_image
73
+ session.current_prompt = user_prompt
74
+ enhanced_text = user_prompt
75
+
76
+ end_time = time.time()
77
  return final_image, enhanced_text, f"Total generation time: {end_time - start_time:.2f}s"
78
 
79
  def ui_reset():
 
89
  reset_button = gr.Button("Reset Session", variant="secondary")
90
 
91
  with gr.Column():
92
+ output_image = gr.Image(label="Generated Image", type="pil")
93
  enhanced_prompt = gr.Textbox(label="Enhanced Prompt", interactive=False)
94
  time_output = gr.Textbox(label="Generation Time", interactive=False)
95