Enzo8930302 commited on
Commit
c97fe2c
·
verified ·
1 Parent(s): a44493b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +23 -8
app.py CHANGED
@@ -6,26 +6,41 @@ Interactive web UI for image generation
6
  import gradio as gr
7
  from bytedream.generator import ByteDreamGenerator
8
  import torch
 
9
 
10
 
11
  # Initialize generator
12
  print("Loading Byte Dream model...")
 
 
 
 
 
13
  try:
14
- generator = ByteDreamGenerator(
15
- model_path="./models/bytedream",
16
- config_path="config.yaml",
17
- device="cpu",
18
- )
 
 
 
 
 
 
 
 
 
19
  print("✓ Model loaded successfully!")
20
  except Exception as e:
21
  print(f"⚠ Warning: Could not load model: {e}")
22
  print(" Please train the model first using: python train.py")
23
  print(" Or download pretrained weights from Hugging Face.")
24
  print("")
25
- print(" To use a model from Hugging Face, run:")
26
- print(" python infer.py --prompt 'your prompt' --model 'username/repo_name'")
27
  print("")
28
- print("Starting in demo mode with random initialization...")
29
  generator = None
30
 
31
 
 
6
  import gradio as gr
7
  from bytedream.generator import ByteDreamGenerator
8
  import torch
9
+ import os
10
 
11
 
12
  # Initialize generator
13
  print("Loading Byte Dream model...")
14
+
15
+ # Check if we should load from Hugging Face
16
+ HF_REPO_ID = os.getenv("HF_REPO_ID", None)
17
+ MODEL_PATH = os.getenv("MODEL_PATH", "./models/bytedream")
18
+
19
  try:
20
+ if HF_REPO_ID:
21
+ print(f"Loading model from Hugging Face: {HF_REPO_ID}...")
22
+ generator = ByteDreamGenerator(
23
+ hf_repo_id=HF_REPO_ID,
24
+ config_path="config.yaml",
25
+ device="cpu",
26
+ )
27
+ else:
28
+ print(f"Loading model from local path: {MODEL_PATH}...")
29
+ generator = ByteDreamGenerator(
30
+ model_path=MODEL_PATH,
31
+ config_path="config.yaml",
32
+ device="cpu",
33
+ )
34
  print("✓ Model loaded successfully!")
35
  except Exception as e:
36
  print(f"⚠ Warning: Could not load model: {e}")
37
  print(" Please train the model first using: python train.py")
38
  print(" Or download pretrained weights from Hugging Face.")
39
  print("")
40
+ print(" To use a model from Hugging Face, set environment variable:")
41
+ print(" HF_REPO_ID=username/repo_name")
42
  print("")
43
+ print("Starting in demo mode...")
44
  generator = None
45
 
46