Enzo8930302 commited on
Commit
a44493b
·
verified ·
1 Parent(s): 0eabd76

Upload infer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. infer.py +22 -6
infer.py CHANGED
@@ -88,7 +88,14 @@ Examples:
88
  "--model", "-m",
89
  type=str,
90
  default=None,
91
- help="Path to model directory (default: uses config)"
 
 
 
 
 
 
 
92
  )
93
 
94
  parser.add_argument(
@@ -115,11 +122,20 @@ Examples:
115
  print("Byte Dream - AI Image Generator")
116
  print("="*60)
117
 
118
- generator = ByteDreamGenerator(
119
- model_path=args.model,
120
- config_path=args.config,
121
- device=args.device,
122
- )
 
 
 
 
 
 
 
 
 
123
 
124
  # Print model info
125
  info = generator.get_model_info()
 
88
  "--model", "-m",
89
  type=str,
90
  default=None,
91
+ help="Path to model directory or Hugging Face repo ID (default: uses config)"
92
+ )
93
+
94
+ parser.add_argument(
95
+ "--hf_repo",
96
+ type=str,
97
+ default=None,
98
+ help="Hugging Face repository ID to load model from (e.g., username/repo)"
99
  )
100
 
101
  parser.add_argument(
 
122
  print("Byte Dream - AI Image Generator")
123
  print("="*60)
124
 
125
+ # Determine if loading from HF or local
126
+ if args.hf_repo:
127
+ print(f"Loading model from Hugging Face: {args.hf_repo}")
128
+ generator = ByteDreamGenerator(
129
+ hf_repo_id=args.hf_repo,
130
+ config_path=args.config,
131
+ device=args.device,
132
+ )
133
+ else:
134
+ generator = ByteDreamGenerator(
135
+ model_path=args.model,
136
+ config_path=args.config,
137
+ device=args.device,
138
+ )
139
 
140
  # Print model info
141
  info = generator.get_model_info()