jree423 commited on
Commit
21f47d0
·
verified ·
1 Parent(s): 7b38ba8

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +45 -9
app.py CHANGED
@@ -1,13 +1,49 @@
1
 
2
- from fastapi import FastAPI, Request
3
- from handler import EndpointHandler
4
  import os
 
 
 
 
5
 
6
- app = FastAPI()
7
- handler = EndpointHandler(os.getcwd())
8
 
9
- @app.post("/")
10
- async def process_request(request: Request):
11
- json_data = await request.json()
12
- return handler(json_data)
13
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
 
 
2
  import os
3
+ import sys
4
+ import json
5
+ import torch
6
+ from model import pipeline
7
 
8
+ # Initialize the model
9
+ model = pipeline()
10
 
11
+ def run(prompt, negative_prompt="", num_paths=96, guidance_scale=7.5, seed=42):
12
+ """Run the model with the given parameters."""
13
+ return model(
14
+ prompt=prompt,
15
+ negative_prompt=negative_prompt,
16
+ num_paths=int(num_paths),
17
+ guidance_scale=float(guidance_scale),
18
+ seed=int(seed)
19
+ )
20
+
21
+ def parse_args():
22
+ """Parse command line arguments."""
23
+ if len(sys.argv) > 1:
24
+ # Command line arguments
25
+ prompt = sys.argv[1]
26
+ negative_prompt = sys.argv[2] if len(sys.argv) > 2 else ""
27
+ num_paths = int(sys.argv[3]) if len(sys.argv) > 3 else 96
28
+ guidance_scale = float(sys.argv[4]) if len(sys.argv) > 4 else 7.5
29
+ seed = int(sys.argv[5]) if len(sys.argv) > 5 else 42
30
+ else:
31
+ # Read from stdin (for API)
32
+ data = json.loads(sys.stdin.read())
33
+ prompt = data.get("prompt", "")
34
+ negative_prompt = data.get("negative_prompt", "")
35
+ num_paths = int(data.get("num_paths", 96))
36
+ guidance_scale = float(data.get("guidance_scale", 7.5))
37
+ seed = int(data.get("seed", 42))
38
+
39
+ return prompt, negative_prompt, num_paths, guidance_scale, seed
40
+
41
+ if __name__ == "__main__":
42
+ # Parse arguments
43
+ prompt, negative_prompt, num_paths, guidance_scale, seed = parse_args()
44
+
45
+ # Run the model
46
+ result = run(prompt, negative_prompt, num_paths, guidance_scale, seed)
47
+
48
+ # Print the result as JSON
49
+ print(json.dumps(result))