yukee1992 commited on
Commit
8007d5e
Β·
verified Β·
1 Parent(s): ba2cbde

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionPipeline
4
+ from PIL import Image
5
+ import io
6
+ import os
7
+ from datetime import datetime
8
+ import time
9
+ import json
10
+ import uuid
11
+ import random
12
+ from fastapi import FastAPI, BackgroundTasks
13
+ from pydantic import BaseModel
14
+ import requests
15
+ from huggingface_hub import HfApi
16
+
17
+ # =============================================
18
+ # CONFIGURATION
19
+ # =============================================
20
+ HF_TOKEN = os.environ.get("HF_TOKEN")
21
+ HF_USERNAME = "yukee1992"
22
+ DATASET_NAME = "video-project-images"
23
+ DATASET_ID = f"{HF_USERNAME}/{DATASET_NAME}"
24
+
25
+ print("=" * 60)
26
+ print("πŸš€ STARTING IMAGE GENERATOR")
27
+ print("=" * 60)
28
+ print(f"πŸ“¦ HF Dataset: {DATASET_ID}")
29
+ print(f"πŸ”‘ HF Token: {'βœ… Set' if HF_TOKEN else '❌ Missing'}")
30
+
31
+ # Create backup directory
32
+ BACKUP_DIR = "generated_images_backup"
33
+ os.makedirs(BACKUP_DIR, exist_ok=True)
34
+
35
+ # Initialize FastAPI
36
+ app = FastAPI()
37
+
38
+ # Global model cache
39
+ model = None
40
+ model_lock = threading.Lock()
41
+
42
+ # =============================================
43
+ # MODEL LOADING
44
+ # =============================================
45
+ def load_model():
46
+ global model
47
+ if model is None:
48
+ with model_lock:
49
+ if model is None:
50
+ print("πŸ”„ Loading model...")
51
+ model = StableDiffusionPipeline.from_pretrained(
52
+ "runwayml/stable-diffusion-v1-5",
53
+ torch_dtype=torch.float32,
54
+ safety_checker=None
55
+ ).to("cpu")
56
+ print("βœ… Model loaded!")
57
+ return model
58
+
59
+ # Preload model at startup
60
+ load_model()
61
+
62
+ # =============================================
63
+ # HF DATASET FUNCTIONS
64
+ # =============================================
65
+ def upload_to_hf_dataset(image, project_id, scene_num):
66
+ """Upload image to HF Dataset"""
67
+ if not HF_TOKEN:
68
+ return None
69
+
70
+ try:
71
+ # Convert image to bytes
72
+ img_bytes = io.BytesIO()
73
+ image.save(img_bytes, format='PNG')
74
+ img_data = img_bytes.getvalue()
75
+
76
+ # Create filename
77
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
78
+ filename = f"scene_{scene_num:03d}_{timestamp}.png"
79
+ path_in_repo = f"data/projects/{project_id}/{filename}"
80
+
81
+ # Upload
82
+ api = HfApi(token=HF_TOKEN)
83
+ api.upload_file(
84
+ path_or_fileobj=img_data,
85
+ path_in_repo=path_in_repo,
86
+ repo_id=DATASET_ID,
87
+ repo_type="dataset"
88
+ )
89
+
90
+ url = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{path_in_repo}"
91
+ return url
92
+ except Exception as e:
93
+ print(f"❌ Upload failed: {e}")
94
+ return None
95
+
96
+ # =============================================
97
+ # IMAGE GENERATION
98
+ # =============================================
99
+ def generate_image(prompt, project_id=None, scene_num=1):
100
+ """Generate a single image"""
101
+ try:
102
+ pipe = load_model()
103
+
104
+ # Generate
105
+ image = pipe(
106
+ prompt,
107
+ num_inference_steps=25,
108
+ guidance_scale=7.5,
109
+ generator=torch.Generator(device="cpu").manual_seed(random.randint(1, 999999))
110
+ ).images[0]
111
+
112
+ # Save locally
113
+ local_path = os.path.join(BACKUP_DIR, f"{uuid.uuid4()}.png")
114
+ image.save(local_path)
115
+
116
+ # Upload to HF Dataset if project_id provided
117
+ hf_url = None
118
+ if project_id:
119
+ hf_url = upload_to_hf_dataset(image, project_id, scene_num)
120
+
121
+ return {
122
+ "image": image,
123
+ "local_path": local_path,
124
+ "hf_url": hf_url
125
+ }
126
+ except Exception as e:
127
+ print(f"❌ Generation failed: {e}")
128
+ raise
129
+
130
+ # =============================================
131
+ # API ENDPOINTS
132
+ # =============================================
133
+ class GenerateRequest(BaseModel):
134
+ prompt: str
135
+ project_id: Optional[str] = None
136
+
137
+ @app.post("/generate")
138
+ async def generate(request: GenerateRequest):
139
+ """Generate a single image"""
140
+ try:
141
+ result = generate_image(request.prompt, request.project_id)
142
+ return {
143
+ "status": "success",
144
+ "hf_url": result["hf_url"],
145
+ "local_path": result["local_path"]
146
+ }
147
+ except Exception as e:
148
+ return {"status": "error", "message": str(e)}
149
+
150
+ @app.get("/health")
151
+ async def health():
152
+ """Health check"""
153
+ return {
154
+ "status": "healthy",
155
+ "model_loaded": model is not None,
156
+ "hf_dataset": DATASET_ID if HF_TOKEN else "disabled"
157
+ }
158
+
159
+ # =============================================
160
+ # GRADIO INTERFACE (Optional UI)
161
+ # =============================================
162
+ if 'gradio' in globals():
163
+ def gradio_generate(prompt):
164
+ if not prompt:
165
+ return None
166
+ result = generate_image(prompt)
167
+ return result["image"]
168
+
169
+ iface = gr.Interface(
170
+ fn=gradio_generate,
171
+ inputs=gr.Textbox(label="Prompt", placeholder="Enter your prompt..."),
172
+ outputs=gr.Image(label="Generated Image"),
173
+ title="Image Generator",
174
+ description="Generate images with Stable Diffusion"
175
+ )
176
+
177
+ # Mount Gradio
178
+ gr.mount_gradio_app(app, iface, path="/")
179
+ else:
180
+ @app.get("/")
181
+ async def root():
182
+ return {"message": "API is running. Use /generate endpoint"}
183
+
184
+ # =============================================
185
+ # MAIN
186
+ # =============================================
187
+ if __name__ == "__main__":
188
+ import uvicorn
189
+ print("\n" + "=" * 60)
190
+ print("🌐 Server starting on port 7860")
191
+ print(f"πŸ“Š API endpoints:")
192
+ print(f" - POST /generate")
193
+ print(f" - GET /health")
194
+ print(f" - GET / (UI if gradio enabled)")
195
+ print("=" * 60)
196
+
197
+ uvicorn.run(app, host="0.0.0.0", port=7860)