islasher commited on
Commit
444c2a1
·
verified ·
1 Parent(s): 2d2a8dc

Update ImageGenerator.py

Browse files
Files changed (1) hide show
  1. ImageGenerator.py +16 -4
ImageGenerator.py CHANGED
@@ -124,13 +124,25 @@ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float
124
  def generate_image(prompt):
125
  # Get the part after the dash
126
  print('PROMPT TO INTRODUCE TO STABLDIFFUSION:', prompt)
127
- if "-" in prompt:
128
- prompt_to_send = prompt.split("-", 1)[-1].strip()
129
- else:
130
- prompt_to_send = prompt.strip()
 
 
131
  # prompt = "a photo of an astronaut riding a horse on mars"
132
  image = pipe(prompt_to_send).images[0]
133
  return image
 
 
 
 
 
 
 
 
 
 
134
  # image.save("astronaut_rides_horse.png")
135
 
136
 
 
124
  def generate_image(prompt):
125
  # Get the part after the dash
126
  print('PROMPT TO INTRODUCE TO STABLDIFFUSION:', prompt)
127
+ prompt_to_send = extract_assistant_response(prompt)
128
+ print('INTRODUCED: ', prompt_to_send)
129
+ # if "-" in prompt:
130
+ # prompt_to_send = prompt.split("-", 1)[-1].strip()
131
+ # else:
132
+ # prompt_to_send = prompt.strip()
133
  # prompt = "a photo of an astronaut riding a horse on mars"
134
  image = pipe(prompt_to_send).images[0]
135
  return image
136
+
137
+ def extract_assistant_response(decoded_text):
138
+ if "<|im_start|>assistant" in decoded_text:
139
+ part = decoded_text.split("<|im_start|>assistant", 1)[1]
140
+ part = part.split("<|im_end|>", 1)[0]
141
+ return part.strip()
142
+ return decoded_text.strip()
143
+
144
+
145
+
146
  # image.save("astronaut_rides_horse.png")
147
 
148