txh17 commited on
Commit
c688ecf
·
verified ·
1 Parent(s): d9e2fc3

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +25 -19
model.py CHANGED
@@ -1,25 +1,31 @@
1
- from diffusers import StableDiffusionPipeline
2
- from transformers import pipeline
3
  import torch
 
 
 
4
 
5
- # 加载Stable Diffusion模型
6
- def load_stable_diffusion_model():
7
- model = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
8
- model.to("cuda" if torch.cuda.is_available() else "cpu") # 如果有GPU,使用GPU加速
9
- return model
10
 
11
- # 加载GPT模型
12
- def load_gpt_model():
13
- gpt_model = pipeline("text-generation", model="gpt2")
14
- return gpt_model
15
 
16
- # 生成图像的函数
17
- def generate_image_from_prompt(model, prompt):
18
- image = model(prompt).images[0]
19
- return image
20
 
21
- # 使用GPT生成图像的详细提示
22
- def generate_prompt(gpt_model, description):
23
- prompt = gpt_model(f"Create a detailed image prompt based on this: {description}")
24
- return prompt[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
25
 
 
1
+ import openai
 
2
  import torch
3
+ from diffusers import StableDiffusionPipeline
4
+ import os
5
+ from dotenv import load_dotenv
6
 
7
+ # 加载环境变量
8
+ load_dotenv()
 
 
 
9
 
10
+ # 获取 API 密钥
11
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
12
+ openai.api_key = OPENAI_API_KEY
 
13
 
14
+ # 加载 Stable Diffusion 模型
15
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v-1-4-original", torch_dtype=torch.float16).to("cuda")
 
 
16
 
17
+ def generate_prompt(description):
18
+ """ 使用 OpenAI 生成图像生成的提示 """
19
+ prompt = f"Generate a detailed prompt for stable diffusion image generation based on the description: {description}"
20
+ response = openai.Completion.create(
21
+ engine="text-davinci-003", # 或其他 GPT-3 模型
22
+ prompt=prompt,
23
+ max_tokens=100
24
+ )
25
+ return response.choices[0].text.strip()
26
+
27
+ def generate_image_from_prompt(prompt):
28
+ """ 使用 Stable Diffusion 生成图像 """
29
+ image = pipe(prompt).images[0]
30
+ return image
31