CyberRohith committed on
Commit
2480131
·
verified ·
1 Parent(s): ab227c9

Upload 4 files

Browse files

Add application files

Files changed (4) hide show
  1. LLM_pipeline.py +31 -0
  2. app.py +41 -0
  3. model_loading.py +83 -0
  4. requirements.txt +14 -0
LLM_pipeline.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from model_loading import GenerationSession
4
+
5
def prompt_enhancer(user_prompt: str) -> str:
    """Ask the local Ollama server to rewrite *user_prompt* for image generation.

    The request goes to the ``mistral`` model on ``localhost:11434``. Prompt
    enhancement is best-effort: on any failure (server unreachable, timeout,
    HTTP error, malformed or empty reply) a warning is printed and the
    original prompt is returned unchanged so image generation can proceed.

    Args:
        user_prompt: The raw prompt typed by the user.

    Returns:
        The rewritten prompt, or *user_prompt* itself when no usable rewrite
        could be obtained.
    """
    instruction = (
        "[INST] You are an image generation prompt engineer. "
        "Rewrite this prompt to be vivid and detailed, under 60 words. "
        "Return ONLY the rewritten prompt, nothing else.\n\n"
        f"Prompt: {user_prompt} [/INST]"
    )
    try:
        response = requests.post(
            "http://localhost:11434/api/generate",
            # No "format": "json" here: the instruction asks for plain text
            # and the reply is passed verbatim to the diffusion pipeline, so
            # forcing JSON output would turn the prompt into a JSON blob.
            json={
                "model": "mistral",
                "prompt": instruction,
                "stream": False,
            },
            timeout=60,
        )
        response.raise_for_status()
        enhanced = response.json()["response"].strip()
    except requests.exceptions.ConnectionError:
        print("Warning: Could not connect to local Ollama.")
        return user_prompt
    except (
        requests.exceptions.RequestException,
        ValueError,
        KeyError,
        TypeError,
        AttributeError,
    ) as exc:
        # Timeout, HTTP error status, non-JSON body, or a payload without a
        # string "response" field.
        print(f"Warning: Ollama prompt enhancement failed ({exc}).")
        return user_prompt

    # An empty rewrite is useless to the image model; keep the user's text.
    return enhanced or user_prompt
26
+
27
def smart_generate(user_prompt: str, session: GenerationSession, strength: float = 0.45):
    """Enhance *user_prompt*, run it through *session*, and return the result.

    Args:
        user_prompt: Raw prompt text from the user.
        session: Generation session whose ``Generate`` method produces images.
        strength: Forwarded to ``session.Generate`` as ``strength``.

    Returns:
        A ``(image, enhanced_prompt)`` tuple, where ``image`` is the first
        element of whatever ``session.Generate`` returned.
    """
    enhanced_prompt = prompt_enhancer(user_prompt)
    first_image = session.Generate(enhanced_prompt, strength=strength)[0]
    return first_image, enhanced_prompt
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
# Compatibility shim: alias the dtype name to float16 when this torch build
# does not define it. NOTE(review): presumably needed so the imports below,
# which may reference torch.float8_e8m0fnu, do not fail on older torch
# versions -- confirm which dependency requires it.
try:
    torch.float8_e8m0fnu
except AttributeError:
    torch.float8_e8m0fnu = torch.float16
5
+ from LLM_pipeline import smart_generate
6
+ from model_loading import GenerationSession
7
+ import time
8
+
9
# GenerationSession samples in 4 steps with an LCMScheduler, so the checkpoint
# configured here should be the LCM checkpoint that model_loading.py names
# (previously this pointed at a plain Stable Diffusion v1.5 repository).
model_id = "simianluo/lcm_dreamshaper_v7"
# Pipelines are loaded once at import time and shared by every UI callback.
session = GenerationSession(model_id)
11
+
12
def ui_handler(user_prompt):
    """Gradio callback: generate an image for *user_prompt* and time the run.

    Args:
        user_prompt: Text entered in the prompt box.

    Returns:
        A ``(image, enhanced_prompt, timing_message)`` tuple matching the
        three output components wired to the Generate button.
    """
    start_time = time.perf_counter()
    image, enhanced_text = smart_generate(user_prompt, session, strength=0.45)
    # smart_generate already returns a single image; unwrap defensively in
    # case a list of images is ever handed back instead.
    final_image = image[0] if isinstance(image, list) and image else image
    # Report the elapsed duration, not the absolute clock reading.
    elapsed = time.perf_counter() - start_time
    print(f"Image generation time: {elapsed:.2f}s")

    return final_image, enhanced_text, f"Total generation time: {elapsed:.2f}s"
20
+
21
def ui_reset():
    """Gradio callback: clear the shared session and blank the outputs.

    Returns:
        A ``(image, enhanced_prompt, timing_message)`` tuple: ``None`` to
        clear the image, plus two status strings for the text outputs.
    """
    session.reset()
    cleared_image = None
    prompt_status = "Session cleared. Next generation will be a brand new Base Image."
    timing_status = "Session reset. Next generation will be a brand new Base Image."
    return cleared_image, prompt_status, timing_status
24
+
25
# UI layout and event wiring for the Gradio app.
with gr.Blocks(title="Active Image Generator", theme=gr.Theme.from_hub("Respair/Shiki")) as demo:
    gr.Markdown("## Active Image Generator\n\nEnter a prompt to generate or modify an image. Each new prompt will build upon the previous image, creating a dynamic and evolving visual experience. Use the reset button to start fresh with a new base image.")

    with gr.Row():
        prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Describe the image you want to create or modify...")
        generate_button = gr.Button("Generate", variant="primary")
        reset_button = gr.Button("Reset Session", variant="secondary")

    with gr.Column():
        output_image = gr.Image(label="Generated Image")
        enhanced_prompt = gr.Textbox(label="Enhanced Prompt", interactive=False)
        # A single timing box shared by both callbacks. Creating a new Textbox
        # inline in each .click() call produced two separate widgets, so the
        # reset button never cleared the one the generate button wrote to.
        generation_time = gr.Textbox(label="Generation Time", interactive=False)

    # Both callbacks return (image, enhanced prompt, timing message).
    shared_outputs = [output_image, enhanced_prompt, generation_time]
    generate_button.click(fn=ui_handler, inputs=prompt_input, outputs=shared_outputs)
    reset_button.click(fn=ui_reset, inputs=None, outputs=shared_outputs)

if __name__ == "__main__":
    demo.launch()
model_loading.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import os
3
+ import ollama
4
+ import torch
5
+ from diffusers import DiffusionPipeline, AutoPipelineForImage2Image, LCMScheduler, AutoPipelineForText2Image
6
+ import time
7
+
8
+ model_id = "simianluo/lcm_dreamshaper_v7"
9
+
10
class GenerationSession:
    """Stateful image session that evolves one image across successive prompts.

    The first :meth:`Generate` call after construction or :meth:`reset` runs
    the text-to-image pipeline; every later call feeds the previous result
    back through the image-to-image pipeline.
    """

    # Default negative prompt shared by both generation methods.
    _DEFAULT_NEGATIVE_PROMPT = "Blurry, low quality, static and distorted image"

    def __init__(self, model_id):
        """Remember *model_id* and load both pipelines immediately.

        Args:
            model_id: Checkpoint identifier passed to
                ``DiffusionPipeline.from_pretrained``. NOTE(review): the
                generation methods use 4 inference steps with an
                ``LCMScheduler``, which presumably requires an LCM-style
                checkpoint -- confirm before passing another model.
        """
        self.model_id = model_id
        self.txt2img_pipeline = None
        self.img2img_pipeline = None
        # Result of the most recent generation (the pipeline's ``.images``),
        # or None when the next call should start from text only.
        self.current_image = None
        self.current_prompt = None
        self._initialize_pipelines()

    def _initialize_pipelines(self):
        """Load the text-to-image pipeline and derive the image-to-image one."""
        print("initializing pipelines...")

        self.txt2img_pipeline = DiffusionPipeline.from_pretrained(
            # Honor the constructor argument; this used to read the
            # module-level ``model_id`` and silently ignore the caller's.
            self.model_id,
            torch_dtype=torch.float16,
            safety_checker=None,
        )

        self.txt2img_pipeline.scheduler = LCMScheduler.from_config(self.txt2img_pipeline.scheduler.config)
        self.txt2img_pipeline.to("cuda")
        self.txt2img_pipeline.enable_attention_slicing()
        self.txt2img_pipeline.enable_vae_slicing()

        # NOTE: torch.compile of the UNet (mode="reduce-overhead",
        # fullgraph=True) was tried here and is currently disabled, so the
        # log line below no longer claims the pipeline is compiled.
        print("Text 2 image pipeline loaded.")

        # Built from the existing pipeline object rather than loaded again.
        self.img2img_pipeline = AutoPipelineForImage2Image.from_pipe(self.txt2img_pipeline)
        print("Image 2 image pipeline loaded (shared weights).")

    def GeneratingBaseImage(self, prompt: str, negative_prompt: str = _DEFAULT_NEGATIVE_PROMPT) -> "list[Image.Image]":
        """Generate a fresh 512x512 result from text only.

        Args:
            prompt: Text prompt for the image.
            negative_prompt: Negative prompt forwarded to the pipeline.

        Returns:
            The pipeline output's ``.images`` (presumably a list of PIL
            images with the diffusers default output type).
        """
        start = time.time()
        images = self.txt2img_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=4,
            guidance_scale=1.0,
            height=512,
            width=512,
        ).images
        print(f"Text to image generated in [{time.time() - start:.2f}s]")
        return images

    def GeneratingVariationImage(self, prompt: str, reference_image: "Image.Image", strength: float = 0.5, negative_prompt: str = _DEFAULT_NEGATIVE_PROMPT) -> "list[Image.Image]":
        """Generate a variation of *reference_image* guided by *prompt*.

        Args:
            prompt: Text prompt for the variation.
            reference_image: Starting image passed to the pipeline as
                ``image``. :meth:`Generate` passes the previous ``.images``
                result here.
            strength: Forwarded to the pipeline as ``strength``.
            negative_prompt: Negative prompt forwarded to the pipeline.

        Returns:
            The pipeline output's ``.images``.
        """
        start = time.time()
        images = self.img2img_pipeline(
            prompt=prompt,
            image=reference_image,
            strength=strength,
            num_inference_steps=4,
            guidance_scale=1.0,
            negative_prompt=negative_prompt,
        ).images
        print(f"Image to image generated in [{time.time() - start:.2f}s]")
        return images

    def Generate(self, new_prompt: str, strength: float = 0.5):
        """Produce the next image of the session and remember it.

        Args:
            new_prompt: Prompt for this step.
            strength: Used only when a previous image exists; forwarded to
                :meth:`GeneratingVariationImage`.

        Returns:
            The ``.images`` result of whichever pipeline ran; callers index
            ``[0]`` to get the image itself.
        """
        if self.current_image is None:
            self.current_image = self.GeneratingBaseImage(new_prompt)
        else:
            self.current_image = self.GeneratingVariationImage(new_prompt, self.current_image, strength)

        self.current_prompt = new_prompt
        return self.current_image

    def reset(self):
        """Forget the current image and prompt so the next call starts fresh."""
        self.current_image = None
        self.current_prompt = None
        print("Session reset. Ready for new generation.")
83
+
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core UI Framework
2
+ gradio>=4.0.0
3
+
4
+ # Deep Learning & Model Inference Frameworks
5
+ torch
6
+ torchvision
7
+ transformers
8
+ diffusers>=0.25.0
9
+ accelerate>=0.26.0
10
+
11
+ # Image Processing and Utilities
12
+ Pillow
13
+ numpy
14
+ huggingface_hub