# txh1 / model.py
# Last commit by txh17: "Update model.py" (6a814d3, verified) — 949 bytes
import openai
import torch
from diffusers import StableDiffusionPipeline
import os
from dotenv import load_dotenv
# εŠ θ½½ηŽ―ε’ƒε˜ι‡
load_dotenv()
# θŽ·ε– API ε―†ι’₯
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY
# 加载 Stable Diffusion ζ¨‘εž‹
pipe = StableDiffusionPipeline.from_pretrained("SfinOe/stable-diffusion-v1.5", torch_dtype=torch.float16).to("cpu")
def generate_prompt(description, *, model="gpt-3.5-turbo-instruct", max_tokens=100):
    """Use OpenAI to expand *description* into a Stable Diffusion prompt.

    Args:
        description: Free-text description of the desired image.
        model: Model name for OpenAI's (legacy) completions endpoint. Defaults
            to ``gpt-3.5-turbo-instruct``, OpenAI's designated replacement for
            the retired ``text-davinci-003`` the original code used.
        max_tokens: Upper bound on the number of tokens generated.

    Returns:
        The generated prompt text with leading/trailing whitespace stripped.

    Errors raised by the OpenAI client are not caught and propagate to the caller.
    """
    prompt = f"Generate a detailed prompt for stable diffusion image generation based on the description: {description}"
    completions = getattr(openai, "completions", None)
    if completions is not None:
        # openai>=1.0: module-level client proxy (reads openai.api_key);
        # openai.Completion was removed in this major version.
        response = completions.create(model=model, prompt=prompt, max_tokens=max_tokens)
    else:
        # openai<1.0: legacy resource-class interface.
        response = openai.Completion.create(model=model, prompt=prompt, max_tokens=max_tokens)
    # Both client generations expose the completion text as choices[0].text.
    return response.choices[0].text.strip()
def generate_image_from_prompt(prompt):
    """Generate an image for *prompt* using the module-level Stable Diffusion pipeline.

    Returns the first entry of the pipeline output's ``images`` list.
    """
    output = pipe(prompt)
    return output.images[0]