make cpu compatible
Browse files- inference.py +5 -4
- requirements.txt +0 -1
inference.py
CHANGED
|
@@ -18,6 +18,7 @@ class DiffusionInference:
|
|
| 18 |
provider="hf-inference",
|
| 19 |
api_key=self.api_key,
|
| 20 |
)
|
|
|
|
| 21 |
|
| 22 |
def text_to_image(self, prompt, model_name=None, negative_prompt=None, seed=None, **kwargs):
|
| 23 |
"""
|
|
@@ -154,17 +155,17 @@ class DiffusionInference:
|
|
| 154 |
if seed is not None:
|
| 155 |
try:
|
| 156 |
# Convert to integer and add to params
|
| 157 |
-
generator = torch.Generator(device=
|
| 158 |
except (ValueError, TypeError):
|
| 159 |
# Use random seed if conversion fails
|
| 160 |
random_seed = random.randint(0, 3999999999) # Max 32-bit integer
|
| 161 |
-
generator = torch.Generator(device=
|
| 162 |
print(f"Warning: Invalid seed value: {seed}, using random seed {random_seed} instead")
|
| 163 |
else:
|
| 164 |
# Generate random seed when none is provided
|
| 165 |
random_seed = random.randint(0, 3999999999) # Max 32-bit integer
|
| 166 |
-
generator = torch.Generator(device=
|
| 167 |
print(f"Using random seed: {random_seed}")
|
| 168 |
-
pipeline = AutoPipelineForText2Image.from_pretrained(model_name, generator=generator, torch_dtype=torch.float16).to(
|
| 169 |
image = pipeline(**kwargs).images[0]
|
| 170 |
return image
|
|
|
|
| 18 |
provider="hf-inference",
|
| 19 |
api_key=self.api_key,
|
| 20 |
)
|
| 21 |
+
self.device = torch.device("cuda" if torch.cuda else "cpu")
|
| 22 |
|
| 23 |
def text_to_image(self, prompt, model_name=None, negative_prompt=None, seed=None, **kwargs):
|
| 24 |
"""
|
|
|
|
| 155 |
if seed is not None:
|
| 156 |
try:
|
| 157 |
# Convert to integer and add to params
|
| 158 |
+
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 159 |
except (ValueError, TypeError):
|
| 160 |
# Use random seed if conversion fails
|
| 161 |
random_seed = random.randint(0, 3999999999) # Max 32-bit integer
|
| 162 |
+
generator = torch.Generator(device=self.device).manual_seed(random_seed)
|
| 163 |
print(f"Warning: Invalid seed value: {seed}, using random seed {random_seed} instead")
|
| 164 |
else:
|
| 165 |
# Generate random seed when none is provided
|
| 166 |
random_seed = random.randint(0, 3999999999) # Max 32-bit integer
|
| 167 |
+
generator = torch.Generator(device=self.device).manual_seed(random_seed)
|
| 168 |
print(f"Using random seed: {random_seed}")
|
| 169 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(model_name, generator=generator, torch_dtype=torch.float16).to(self.device)
|
| 170 |
image = pipeline(**kwargs).images[0]
|
| 171 |
return image
|
requirements.txt
CHANGED
|
@@ -8,7 +8,6 @@ torch
|
|
| 8 |
transformers
|
| 9 |
diffusers
|
| 10 |
spaces>=0.14.0
|
| 11 |
-
xformers
|
| 12 |
numpy
|
| 13 |
accelerate
|
| 14 |
sentencepiece
|
|
|
|
| 8 |
transformers
|
| 9 |
diffusers
|
| 10 |
spaces>=0.14.0
|
|
|
|
| 11 |
numpy
|
| 12 |
accelerate
|
| 13 |
sentencepiece
|