File size: 5,046 Bytes
557227d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#!/usr/bin/env python3
# This script demonstrates how to test your Hugging Face Inference Endpoint
# for image generation. Supply your API token and endpoint URL with the
# --token and --url command-line options.

import requests
import json
import base64
from PIL import Image
import io
import argparse
import os

def test_inference_endpoint(api_token, api_url, prompt, negative_prompt=None,
                            seed=None, inference_steps=30, guidance_scale=7,
                            width=1024, height=768, output_dir="generated_images",
                            timeout=None):
    """
    Test a Hugging Face Inference Endpoint for image generation.

    Sends ``prompt`` plus generation parameters as a JSON POST request to
    ``api_url``. If the response is a non-empty list whose first element
    carries a base64-encoded ``generated_image``, the image is decoded and
    saved as ``<output_dir>/generated_<seed>.png``. The seed in the filename
    is the one reported by the endpoint, or ``unknown_seed`` if none is
    reported.

    Args:
        api_token (str): Your Hugging Face API token
        api_url (str): The URL of your inference endpoint
        prompt (str): The text prompt for image generation
        negative_prompt (str, optional): Negative prompt to guide generation
        seed (int, optional): Random seed for reproducibility (0 is allowed)
        inference_steps (int): Number of inference steps
        guidance_scale (float): Guidance scale for generation
        width (int): Image width
        height (int): Image height
        output_dir (str): Directory to save generated images
        timeout (float or tuple, optional): Forwarded to ``requests.post``.
            ``None`` (the default) waits indefinitely, as before.

    Returns:
        PIL.Image.Image or None: The decoded image on success. ``None`` if the
        request failed, the endpoint reported an error, or the response had
        an unexpected shape. Failures are printed to stdout, not raised.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    headers = {
        "Authorization": f"Bearer {api_token}",
        "Content-Type": "application/json"
    }

    # NOTE(review): these key names must match what the endpoint's handler
    # reads (e.g. "inference_steps") -- confirm against the deployed handler.
    parameters = {
        "width": width,
        "height": height,
        "inference_steps": inference_steps,
        "guidance_scale": guidance_scale
    }

    # An empty negative prompt carries no information, so truthiness is fine.
    if negative_prompt:
        parameters["negative_prompt"] = negative_prompt
    # Compare against None explicitly: seed=0 is a valid, reproducible seed
    # and must not be silently dropped from the request.
    if seed is not None:
        parameters["seed"] = seed

    payload = {
        "inputs": prompt,
        "parameters": parameters
    }

    print(f"Sending request to {api_url}...")
    print(f"Prompt: '{prompt}'")

    try:
        response = requests.post(api_url, headers=headers, json=payload,
                                 timeout=timeout)

        if response.status_code != 200:
            print(f"Error: {response.status_code} - {response.text}")
            return None

        result = response.json()

        # The endpoint may report failures as {"error": ...} with HTTP 200.
        if isinstance(result, dict) and "error" in result:
            print(f"API Error: {result['error']}")
            return None

        if not (isinstance(result, list) and result):
            print("Unexpected response format:", result)
            return None

        item = result[0]
        # Guard the type before using `in` / `.get`: a non-dict element would
        # otherwise either raise or do a misleading substring test.
        if not isinstance(item, dict) or "generated_image" not in item:
            print("Response doesn't contain 'generated_image' field")
            return None

        # Convert the base64-encoded image to a PIL Image
        image_bytes = base64.b64decode(item["generated_image"])
        image = Image.open(io.BytesIO(image_bytes))

        # Filename is derived from the seed reported by the endpoint only.
        used_seed = item.get("seed", "unknown_seed")
        filename = os.path.join(output_dir, f"generated_{used_seed}.png")

        image.save(filename)
        print(f"Image saved to {filename}")
        print(f"Seed: {used_seed}")

        return image

    # Deliberate top-level boundary for a CLI test helper: report any failure
    # (network, JSON, base64, image decoding, disk) and return None.
    except Exception as e:
        print(f"Error: {e}")
    return None

if __name__ == "__main__":
    import sys

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Test Hugging Face Inference Endpoints for image generation")
    # Secrets passed on argv are visible in `ps` output and shell history, so
    # the HF_TOKEN environment variable is accepted as an alternative.
    parser.add_argument(
        "--token",
        default=os.environ.get("HF_TOKEN"),
        help="Your Hugging Face API token (defaults to the HF_TOKEN environment variable)",
    )
    parser.add_argument("--url", required=True, help="URL of your inference endpoint")
    parser.add_argument("--prompt", required=True, help="Text prompt for image generation")
    parser.add_argument("--negative_prompt", help="Negative prompt")
    parser.add_argument("--seed", type=int, help="Random seed for reproducibility")
    parser.add_argument("--steps", type=int, default=30, help="Number of inference steps")
    parser.add_argument("--guidance", type=float, default=7, help="Guidance scale")
    parser.add_argument("--width", type=int, default=1024, help="Image width")
    parser.add_argument("--height", type=int, default=768, help="Image height")
    parser.add_argument("--output_dir", default="generated_images", help="Directory to save generated images")

    args = parser.parse_args()

    # A token is still mandatory; parser.error exits with status 2, matching
    # argparse's behaviour for a missing required option.
    if not args.token:
        parser.error("an API token is required: pass --token or set HF_TOKEN")

    image = test_inference_endpoint(
        api_token=args.token,
        api_url=args.url,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        seed=args.seed,
        inference_steps=args.steps,
        guidance_scale=args.guidance,
        width=args.width,
        height=args.height,
        output_dir=args.output_dir
    )

    # test_inference_endpoint returns None (after printing the reason) on
    # every failure path; surface that to shells/CI as a non-zero exit code.
    sys.exit(0 if image is not None else 1)