File size: 8,463 Bytes
bdb7403
 
dbf9089
 
 
bdb7403
dbf9089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdb7403
 
 
 
 
dbf9089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdb7403
 
 
 
 
dbf9089
 
bdb7403
 
 
 
dbf9089
 
bdb7403
dbf9089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdb7403
dbf9089
bdb7403
dbf9089
 
bdb7403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbf9089
 
bdb7403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbf9089
 
 
 
 
bdb7403
 
 
 
 
 
 
 
 
 
 
 
 
dbf9089
 
 
 
 
 
bdb7403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbf9089
 
 
 
 
 
 
 
 
 
bdb7403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbf9089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdb7403
dbf9089
 
 
 
 
 
 
 
 
 
 
 
 
bdb7403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbf9089
 
 
 
 
bdb7403
 
12eff42
 
bdb7403
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
import logging
import os
import subprocess
import sys
import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel

# Install required packages
def install_packages(packages=None):
    """Install any required third-party package that cannot be imported.

    Each requirement is probed by importing its module; ``pip install`` is
    run (with the current interpreter) only for requirements whose import
    fails.

    Args:
        packages: Optional mapping of pip requirement string to the module
            name used to probe for it. The two can differ (``pillow`` is
            imported as ``PIL``). Defaults to the packages this service uses.

    Raises:
        subprocess.CalledProcessError: If a ``pip install`` command fails.
    """
    import importlib

    if packages is None:
        packages = {
            "fastapi": "fastapi",
            "uvicorn[standard]": "uvicorn",
            # Pillow is imported as PIL; probing the name "pillow" always
            # failed and forced a pip install on every start.
            "pillow": "PIL",
            "huggingface_hub": "huggingface_hub",
            "pydantic": "pydantic",
        }

    for package, module_name in packages.items():
        try:
            importlib.import_module(module_name)
        except ImportError:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
            # Make the freshly installed distribution visible to later imports.
            importlib.invalidate_caches()
        else:
            print(f"{package} already installed")

# Install packages before importing: this call runs at import time so that the
# third-party imports below (uvicorn, fastapi) can succeed on a fresh environment.
install_packages()

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

# Define models directly in the file
class ResponseFormat(str, Enum):
    """Allowed values for the ``response_format`` field of an image request.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    # Value "url"
    URL = "url"
    # Value "b64_json"
    B64_JSON = "b64_json"

class ImageGenerationRequest(BaseModel):
    """Request body accepted by ``POST /v1/images/generations``.

    Only ``prompt`` is required; every other field has a default.
    """
    # Text prompt; the endpoint rejects blank prompts and prompts over 4000 characters.
    prompt: str
    # Model id; the endpoint remaps "dall-e-3" and "dall-e-2" before generation.
    model: str = "dall-e-3"
    # presumably the number of images to generate — confirm against ImageGenerator
    n: int = 1
    # presumably "WIDTHxHEIGHT" in pixels — confirm against ImageGenerator
    size: str = "1024x1024"
    # Passed through to the generator unvalidated; accepted values are not visible here.
    quality: str = "standard"
    # One of the ResponseFormat values; defaults to ResponseFormat.URL.
    response_format: ResponseFormat = ResponseFormat.URL

class ImageData(BaseModel):
    """One image entry in the ``data`` list of an image generation response.

    Every field is optional and defaults to ``None``. Which fields are
    populated is decided by the image generator, which is not visible in
    this file.
    """

    # Fields are declared Optional so that an explicit ``None`` validates;
    # a bare ``str = None`` annotation rejects None under Pydantic v2.

    # presumably a link to the generated image — confirm against ImageGenerator
    url: Optional[str] = None
    # presumably the base64-encoded image payload — confirm against ImageGenerator
    b64_json: Optional[str] = None
    # presumably the prompt actually used for generation — confirm against ImageGenerator
    revised_prompt: Optional[str] = None

class ImageGenerationResponse(BaseModel):
    """Response body returned by ``POST /v1/images/generations``."""
    # Integer timestamp; the endpoint fills it with int(time.time()).
    created: int
    # Image entries produced for the request.
    data: List[ImageData]

class ErrorResponse(BaseModel):
    """Error envelope used as the JSON body by the global exception handler."""
    # Free-form error details; the global handler supplies "message", "type" and "code" keys.
    error: dict

class ModelInfo(BaseModel):
    """Description of one model as listed by ``GET /v1/models``."""
    # Model identifier, e.g. "dall-e-3".
    id: str
    # Integer timestamp; presumably Unix seconds — confirm.
    created: int
    # Owner label; the listing endpoint uses "tti-frame" for every model.
    owned_by: str

class ModelsResponse(BaseModel):
    """Response body returned by ``GET /v1/models``."""
    # The advertised models.
    data: List[ModelInfo]

# Import the modified image generator
from image_generator import ImageGenerator

# Setup logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global image generator instance. Stays None until the lifespan handler
# creates it at application startup; endpoints check it before use.
image_generator = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared image generator on startup and clean it up on shutdown.

    Startup builds the module-level ``image_generator`` from the ``HF_TOKEN``
    environment variable, applies ``BASE_URL`` (default
    ``http://localhost:8000``) and mounts the generator's output directory
    at ``/images``. Shutdown calls ``image_generator.cleanup()``.

    Args:
        app: The FastAPI application whose lifespan is being managed.
    """
    global image_generator

    logger.info("Starting TTI Frame API...")

    # A missing token is only warned about; startup continues regardless.
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        logger.warning("HF_TOKEN environment variable not set. Image generation may fail.")

    image_generator = ImageGenerator(hf_token=hf_token)

    # From here on the generator exists, so cleanup must run even when the
    # remaining setup fails or an exception is thrown in at the yield.
    try:
        # Set base URL for serving images
        base_url = os.getenv("BASE_URL", "http://localhost:8000")
        image_generator.set_config(base_url=base_url)

        # Serve the generator's output directory as static files
        app.mount("/images", StaticFiles(directory=image_generator.output_dir), name="images")

        logger.info(f"Image generator initialized with output directory: {image_generator.output_dir}")

        yield
    finally:
        logger.info("Shutting down TTI Frame API...")
        if image_generator:
            image_generator.cleanup()

# Create FastAPI app. The `lifespan` context manager is passed in so that
# the shared image generator is set up at startup and cleaned up at shutdown.
app = FastAPI(
    title="TTI Frame - OpenAI Compatible Text-to-Image API",
    description="A FastAPI wrapper providing OpenAI-compatible endpoints for text-to-image generation",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware.
# A wildcard origin must not be combined with credentials: browsers reject
# such responses, and echoing arbitrary origins with credentials would let
# any site make credentialed calls. No endpoint in this file reads cookies or
# auth headers, so credentials are disabled while the origin list is "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # List explicit origins before turning credentials on
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def root():
    """Describe the service: name, version, docs path and image output directory."""
    if image_generator:
        output_dir = image_generator.output_dir
    else:
        # The generator has not been created (or is falsy).
        output_dir = "Not initialized"

    service_info = {
        "message": "TTI Frame - OpenAI Compatible Text-to-Image API",
        "version": "1.0.0",
        "docs": "/docs",
        "output_dir": output_dir,
    }
    return service_info

@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
    """Return the fixed list of model ids this API advertises (OpenAI compatible)."""
    # Every entry shares the same creation timestamp and owner label.
    advertised_ids = (
        "dall-e-3",
        "dall-e-2",
        "black-forest-labs/flux-schnell",
    )
    catalogue = [
        ModelInfo(id=model_id, created=1677649963, owned_by="tti-frame")
        for model_id in advertised_ids
    ]
    return ModelsResponse(data=catalogue)

@app.post("/v1/images/generations", response_model=ImageGenerationResponse)
async def create_image(request: ImageGenerationRequest):
    """
    Generate images from text prompts (OpenAI compatible)

    Validates the prompt, maps the OpenAI model aliases to the backing model
    id, and delegates generation to the shared image generator.

    Args:
        request: Parsed request body. Its ``model`` field is rewritten in
            place when it names a mapped alias.

    Returns:
        ImageGenerationResponse holding the current timestamp and the image
        data returned by the generator.

    Raises:
        HTTPException: 500 when the generator is not initialized or
            generation fails; 400 when the prompt is blank or longer than
            4000 characters.
    """
    if not image_generator:
        raise HTTPException(
            status_code=500,
            detail="Image generator not initialized. Check HF_TOKEN environment variable."
        )

    logger.info(f"Received image generation request: {request.prompt[:50]}...")

    # Validate request
    if not request.prompt or not request.prompt.strip():
        raise HTTPException(
            status_code=400,
            detail="Prompt cannot be empty"
        )

    if len(request.prompt) > 4000:
        raise HTTPException(
            status_code=400,
            detail="Prompt too long. Maximum 4000 characters allowed."
        )

    # Map OpenAI model names to HuggingFace models; other ids pass through.
    model_mapping = {
        "dall-e-3": "black-forest-labs/flux-schnell",
        "dall-e-2": "black-forest-labs/flux-schnell",
    }
    mapped_model = model_mapping.get(request.model)
    if mapped_model is not None:
        request.model = mapped_model

    # Only the generator call can fail unexpectedly, so only it is guarded.
    try:
        image_data = await image_generator.generate_images(request)
    except HTTPException:
        # Let HTTP errors raised by the generator reach the client unchanged.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"Image generation failed: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Image generation failed: {str(e)}"
        ) from e

    response = ImageGenerationResponse(
        created=int(time.time()),
        data=image_data
    )

    logger.info(f"Successfully generated {len(image_data)} images")
    return response

@app.get("/health")
async def health_check():
    """Report liveness, the current timestamp and the image generator state."""
    now = int(time.time())

    # output_dir is only reported when a generator is available.
    output_dir = None
    if image_generator:
        output_dir = image_generator.output_dir

    return {
        "status": "healthy",
        "timestamp": now,
        "generator_initialized": image_generator is not None,
        "output_dir": output_dir,
    }

@app.get("/config")
async def get_config():
    """Return the image generator's current settings, or an error payload when it is unavailable."""
    if image_generator:
        # The token itself is never returned, only whether one is set.
        token_present = bool(image_generator.hf_token)
        return {
            "output_dir": image_generator.output_dir,
            "base_url": image_generator.base_url,
            "default_model": image_generator.default_model,
            "hf_token_set": token_present,
        }

    return {"error": "Image generator not initialized"}

@app.post("/config")
async def update_config(hf_token: str = None, base_url: str = None, default_model: str = None):
    """Forward new settings to the image generator.

    All three values are passed to ``set_config`` as given, including any
    left at ``None``.

    Raises:
        HTTPException: 500 when the image generator is not initialized.
    """
    if not image_generator:
        raise HTTPException(status_code=500, detail="Image generator not initialized")

    new_settings = {
        "hf_token": hf_token,
        "base_url": base_url,
        "default_model": default_model,
    }
    image_generator.set_config(**new_settings)

    return {"message": "Configuration updated successfully"}

@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Turn any otherwise-unhandled exception into a generic JSON 500 response.

    Args:
        request: The incoming request (unused).
        exc: The exception that reached the handler.

    Returns:
        JSONResponse with status 500 and a fixed ``ErrorResponse`` body. The
        exception text goes to the log only, never to the client.
    """
    # exc_info attaches the traceback; the message alone is rarely enough to debug.
    logger.error(f"Unhandled exception: {exc}", exc_info=exc)

    payload = ErrorResponse(
        error={
            "message": "Internal server error",
            "type": "server_error",
            "code": "internal_error"
        }
    )
    # Pydantic v2 renamed .dict() to .model_dump(); use whichever exists.
    to_dict = getattr(payload, "model_dump", None) or payload.dict
    return JSONResponse(status_code=500, content=to_dict())

if __name__ == "__main__":
    # Only warn about a missing token; nothing is set here and the server still starts.
    if not os.getenv("HF_TOKEN"):
        print(
            "Warning: HF_TOKEN environment variable not set.",
            "Please set it with: export HF_TOKEN=your_huggingface_token",
            sep="\n",
        )

    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
        "log_level": "info",
    }
    # The app is given as the import string "main:app", so this module must be importable as `main`.
    uvicorn.run("main:app", **server_options)