|
|
""" |
|
|
Trouter-Imagine-1 Image Generator |
|
|
Apache 2.0 License |
|
|
A Gradio-based image generation interface using OpenTrouter/Trouter-Imagine-1 |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import requests |
|
|
import io |
|
|
from PIL import Image |
|
|
import base64 |
|
|
import json |
|
|
import os |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
# Hugging Face Inference API endpoint serving the Trouter-Imagine-1 model.
API_URL = "https://api-inference.huggingface.co/models/OpenTrouter/Trouter-Imagine-1"

# Access token taken from the environment; None when HF_TOKEN is not set,
# in which case requests are sent without an Authorization header.
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
|
class ImageGenerator:
    """Main class for handling image generation requests.

    POSTs text-to-image requests to an inference endpoint, decodes the
    returned bytes with PIL, and keeps an in-memory list of successful
    generations for display.
    """

    def __init__(self, api_url, token, timeout=120):
        """
        Args:
            api_url: Endpoint URL that generation requests are POSTed to
            token: Bearer token for the Authorization header; when falsy,
                no Authorization header is sent
            timeout: Seconds to wait for the HTTP request (default 120)
        """
        self.api_url = api_url
        self.headers = {"Authorization": f"Bearer {token}"} if token else {}
        self.timeout = timeout
        # One dict per successful generation, appended in call order.
        self.generation_history = []

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       num_inference_steps=30, guidance_scale=7.5, seed=-1):
        """
        Generate an image from a text prompt

        Args:
            prompt: Text description of the desired image
            negative_prompt: Things to avoid in the generation
            width: Image width in pixels
            height: Image height in pixels
            num_inference_steps: Number of denoising steps (more = better quality but slower)
            guidance_scale: How closely to follow the prompt (higher = more strict)
            seed: Random seed for reproducibility (-1 for random)

        Returns:
            Tuple (image, message): a fully decoded PIL Image and a success
            message, or (None, error message) on any failure. Only successful
            generations are recorded in ``generation_history``.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "negative_prompt": negative_prompt,
                "width": width,
                "height": height,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
            },
        }
        # -1 means "let the server pick"; only send an explicit seed otherwise.
        if seed != -1:
            payload["parameters"]["seed"] = seed

        try:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=self.timeout,
            )

            if response.status_code != 200:
                error_msg = f"Error {response.status_code}: {response.text}"
                return None, error_msg

            image = Image.open(io.BytesIO(response.content))
            # Image.open is lazy (header only). Decode now so a truncated or
            # corrupt payload is reported here instead of failing later when
            # the caller renders the image.
            image.load()
        except requests.exceptions.Timeout:
            return None, "⏱️ Request timed out. Try reducing image size or steps."
        except Exception as e:
            # UI boundary: surface any failure as a message rather than crash.
            return None, f"❌ Error: {str(e)}"
        else:
            self.generation_history.append({
                "timestamp": datetime.now().isoformat(),
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "parameters": {
                    "width": width,
                    "height": height,
                    "steps": num_inference_steps,
                    "guidance": guidance_scale,
                    "seed": seed,
                },
            })
            success_msg = f"✅ Image generated successfully! ({width}x{height}px)"
            return image, success_msg

    def get_history_summary(self):
        """Get a summary of generation history.

        Returns a text block with the total count and the five most recent
        entries (timestamp plus prompt, truncated to 50 characters with a
        trailing "..." only when it was actually shortened).
        """
        if not self.generation_history:
            return "No generations yet."

        parts = [
            f"Total generations: {len(self.generation_history)}\n\n",
            "Recent generations:\n",
        ]
        for i, entry in enumerate(self.generation_history[-5:], 1):
            prompt = entry["prompt"]
            preview = prompt if len(prompt) <= 50 else prompt[:50] + "..."
            parts.append(f"\n{i}. {entry['timestamp']}\n")
            parts.append(f" Prompt: {preview}\n")

        return "".join(parts)
|
|
|
|
|
|
|
|
|
|
|
# Single shared generator used by the Gradio callbacks below.
generator = ImageGenerator(api_url=API_URL, token=HF_TOKEN)
|
|
|
|
|
|
|
|
def generate_wrapper(prompt, negative_prompt, width, height, steps, guidance, seed, progress=gr.Progress()): |
|
|
"""Wrapper function for Gradio interface with progress tracking""" |
|
|
if not prompt.strip(): |
|
|
return None, "⚠️ Please enter a prompt!" |
|
|
|
|
|
progress(0, desc="Initializing...") |
|
|
progress(0.3, desc="Sending request to model...") |
|
|
|
|
|
image, message = generator.generate_image( |
|
|
prompt=prompt, |
|
|
negative_prompt=negative_prompt, |
|
|
width=int |