kunalpro379 commited on
Commit
b73c294
·
verified ·
1 Parent(s): 6654937

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -23
app.py CHANGED
@@ -6,7 +6,6 @@ from tensorflow.keras.models import load_model
6
  from PIL import Image
7
  import io
8
  import os
9
- import cv2
10
 
11
  app = FastAPI(title="GAN Image Generator API")
12
 
@@ -26,22 +25,21 @@ generator = load_model(model_path)
26
  def preprocess_sketch(image_bytes):
27
  """Process uploaded sketch image for the GAN"""
28
  try:
29
- # Convert bytes to PIL Image
30
  img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
31
- img = np.array(img)
32
-
33
- # Convert to grayscale
34
- gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
35
-
36
- # Resize to model's expected input size (adjust dimensions as needed)
37
- resized = cv2.resize(gray, (256, 256))
38
-
39
- # Normalize pixel values to [-1, 1] range
40
- normalized = (resized.astype(np.float32) / 127.5) - 1.0
41
-
42
- # Add batch and channel dimensions
43
- processed = np.expand_dims(normalized, axis=[0, -1])
44
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  return processed
46
  except Exception as e:
47
  raise ValueError(f"Image processing failed: {str(e)}")
@@ -51,21 +49,23 @@ async def generate_from_sketch(file: UploadFile = File(...)):
51
  try:
52
  # Read uploaded file
53
  contents = await file.read()
54
-
55
  # Process sketch
56
  processed_sketch = preprocess_sketch(contents)
57
-
58
- # Generate image (modify this to use your actual GAN prediction)
59
  generated = generator.predict(processed_sketch)
60
- generated = (generated.squeeze() * 255).astype(np.uint8)
61
-
 
62
  # Convert to bytes
63
  img = Image.fromarray(generated)
64
  img_byte_arr = io.BytesIO()
65
  img.save(img_byte_arr, format='PNG')
66
  img_byte_arr.seek(0)
67
-
68
  return StreamingResponse(img_byte_arr, media_type="image/png")
69
-
70
  except Exception as e:
71
- raise HTTPException(status_code=400, detail=str(e))
 
 
6
  from PIL import Image
7
  import io
8
  import os
 
9
 
10
  app = FastAPI(title="GAN Image Generator API")
11
 
 
25
  def preprocess_sketch(image_bytes):
26
  """Process uploaded sketch image for the GAN"""
27
  try:
28
+ # Convert bytes to PIL Image and ensure it's RGB
29
  img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ # Resize to model's expected input size
32
+ img = img.resize((256, 256))
33
+
34
+ # Convert to numpy array and normalize
35
+ img_array = np.array(img).astype(np.float32) / 255.0
36
+
37
+ # Ensure the array has 3 channels
38
+ if len(img_array.shape) == 2:
39
+ img_array = np.stack((img_array,) * 3, axis=-1)
40
+
41
+ # Add batch dimension (1, 256, 256, 3)
42
+ processed = np.expand_dims(img_array, axis=0)
43
  return processed
44
  except Exception as e:
45
  raise ValueError(f"Image processing failed: {str(e)}")
 
49
  try:
50
  # Read uploaded file
51
  contents = await file.read()
52
+
53
  # Process sketch
54
  processed_sketch = preprocess_sketch(contents)
55
+
56
+ # Generate image using GAN
57
  generated = generator.predict(processed_sketch)
58
+ generated = np.clip(generated[0], 0, 1) * 255
59
+ generated = generated.astype(np.uint8)
60
+
61
  # Convert to bytes
62
  img = Image.fromarray(generated)
63
  img_byte_arr = io.BytesIO()
64
  img.save(img_byte_arr, format='PNG')
65
  img_byte_arr.seek(0)
66
+
67
  return StreamingResponse(img_byte_arr, media_type="image/png")
68
+
69
  except Exception as e:
70
+ raise HTTPException(status_code=400, detail=str(e))
71
+