arevedudaa commited on
Commit
7c3f1c6
·
verified ·
1 Parent(s): 04992f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -8
app.py CHANGED
@@ -1,20 +1,42 @@
1
- from fastapi import FastAPI, Query
2
- from gradio_client import Client, handle_file
3
  import requests
 
4
 
5
  app = FastAPI()
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  @app.get("/generate")
8
  def generate_image(prompt: str, image_url: str):
9
- client = Client("yanze/PuLID-FLUX")
10
- id_image = handle_file(image_url)
 
 
 
11
  width, height = 1080, 1080 # Instagram format (fixed size)
12
- base_url = "https://yanze-pulid-flux.hf.space/file="
13
 
14
  try:
15
  result = client.predict(
16
  prompt=prompt,
17
- id_image=id_image,
18
  start_step=0,
19
  guidance=4,
20
  seed="-1",
@@ -29,8 +51,12 @@ def generate_image(prompt: str, image_url: str):
29
  api_name="/generate_image"
30
  )
31
 
 
 
 
32
  file_path = result[0]
33
- full_url = f"{base_url}{file_path}"
34
  return {"image_url": full_url}
 
35
  except Exception as e:
36
- return {"error": str(e)}
 
1
+ from fastapi import FastAPI, Query, HTTPException
2
+ from gradio_client import Client
3
  import requests
4
+ import os
5
 
6
  app = FastAPI()
7
 
8
+ HF_MODEL = "yanze/PuLID-FLUX"
9
+ BASE_URL = "https://yanze-pulid-flux.hf.space/file=" # Ensure this is correct
10
+ TEMP_DIR = "/tmp" # Folder to store temp image downloads
11
+
12
+ def download_image(image_url: str) -> str:
13
+ """Downloads image from URL and saves it locally."""
14
+ try:
15
+ response = requests.get(image_url, stream=True)
16
+ if response.status_code != 200:
17
+ raise HTTPException(status_code=400, detail="Failed to download image")
18
+
19
+ filename = os.path.join(TEMP_DIR, "input_image.jpg")
20
+ with open(filename, "wb") as f:
21
+ for chunk in response.iter_content(1024):
22
+ f.write(chunk)
23
+ return filename
24
+ except Exception as e:
25
+ raise HTTPException(status_code=500, detail=f"Error downloading image: {str(e)}")
26
+
27
  @app.get("/generate")
28
  def generate_image(prompt: str, image_url: str):
29
+ client = Client(HF_MODEL)
30
+
31
+ # Download image from URL first
32
+ image_path = download_image(image_url)
33
+
34
  width, height = 1080, 1080 # Instagram format (fixed size)
 
35
 
36
  try:
37
  result = client.predict(
38
  prompt=prompt,
39
+ id_image=image_path, # File path instead of URL
40
  start_step=0,
41
  guidance=4,
42
  seed="-1",
 
51
  api_name="/generate_image"
52
  )
53
 
54
+ if not result or not isinstance(result, list) or len(result) == 0:
55
+ raise HTTPException(status_code=500, detail="Model did not return a valid response")
56
+
57
  file_path = result[0]
58
+ full_url = f"{BASE_URL}{file_path}"
59
  return {"image_url": full_url}
60
+
61
  except Exception as e:
62
+ raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")