Arkm20 commited on
Commit
8332ab6
·
verified ·
1 Parent(s): d29325d

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +38 -0
main.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from gradio_client import Client
4
+
5
+ # Init FastAPI app
6
+ app = FastAPI()
7
+
8
+ # Initialize the Gradio Client
9
+ client = Client("Efficient-Large-Model/SanaSprint")
10
+
11
+ # Request body schema
12
+ class GenerationRequest(BaseModel):
13
+ prompt: str
14
+ model_size: str = "1.6B"
15
+ seed: int = 0
16
+ randomize_seed: bool = True
17
+ width: int = 1024
18
+ height: int = 1024
19
+ guidance_scale: float = 4.5
20
+ num_inference_steps: int = 2
21
+
22
+ @app.post("/generate")
23
+ async def generate_image(request: GenerationRequest):
24
+ try:
25
+ result = client.predict(
26
+ prompt=request.prompt,
27
+ model_size=request.model_size,
28
+ seed=request.seed,
29
+ randomize_seed=request.randomize_seed,
30
+ width=request.width,
31
+ height=request.height,
32
+ guidance_scale=request.guidance_scale,
33
+ num_inference_steps=request.num_inference_steps,
34
+ api_name="/infer"
35
+ )
36
+ return {"result": result}
37
+ except Exception as e:
38
+ raise HTTPException(status_code=500, detail=str(e))