Spaces:
Runtime error
Runtime error
import argparse
import asyncio
import logging
import os
import secrets
import sqlite3
import sys
import zipfile
from asyncio import Semaphore, create_task, gather
from contextlib import contextmanager
from datetime import datetime
from io import BytesIO
from typing import List, Optional

import gradio as gr
import replicate
import requests
import uvicorn
from fastapi import (
    Depends,
    FastAPI,
    Form,
    HTTPException,
    Query,
    Request,
    Response,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import FileResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from mistralai import Mistral
from pydantic import BaseModel
from starlette import status  # HTTP status codes
print(f"Arguments: {sys.argv}")

# Module-wide logger. Several route handlers below call `logger.info(...)` /
# `logger.error(...)`; previously no `logger` was ever defined, so every such
# call raised NameError at request time.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# API credentials / IDs from the environment.
token = os.getenv("HF_TOKEN")
api_key = os.getenv("MISTRAL_API_KEY")
agent_id = os.getenv("MISTRAL_FLUX_AGENT")

# ANSI escape codes for colored console output (may be removed if unused).
HEADER = "\033[38;2;255;255;153m"
TITLE = "\033[38;2;255;255;153m"
MENU = "\033[38;2;255;165;0m"
SUCCESS = "\033[38;2;153;255;153m"
ERROR = "\033[38;2;255;69;0m"
MAIN = "\033[38;2;204;204;255m"
SPEAKER1 = "\033[38;2;173;216;230m"
SPEAKER2 = "\033[38;2;255;179;102m"
RESET = "\033[0m"

DOWNLOAD_DIR = "/home/user/app/flux-pics"  # where generated images are written
DATABASE_PATH = "/home/user/app/flux_logs.db"  # SQLite database file
TIMEOUT_DURATION = 900  # timeout in seconds for generation / downloads
# IMPORTANT: this directory must exist and contain the images.
IMAGE_STORAGE_PATH = DOWNLOAD_DIR

app = FastAPI()
security = HTTPBasic()

# Username and password for HTTP Basic auth.
USERNAME = os.getenv("CF_USER", "default_user")
PASSWORD = os.getenv("CF_PASSWORD", "default_password")

# Static file mounts (the StaticFiles middleware requires the directories to exist).
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/flux-pics", StaticFiles(directory=IMAGE_STORAGE_PATH), name="flux-pics")
templates = Jinja2Templates(directory="templates")
# Database helpers
@contextmanager
def get_db_connection(db_path=DATABASE_PATH):
    """Context manager yielding a sqlite3 connection that is always closed.

    Bug fix: the function is used throughout the file as
    ``with get_db_connection() as conn:`` but was a plain generator — plain
    generators do not implement the context-manager protocol, so every such
    ``with`` raised TypeError. The missing ``@contextmanager`` decorator
    (``contextmanager`` is already imported) restores the intended behavior.
    """
    conn = sqlite3.connect(db_path)
    try:
        yield conn
    finally:
        # Always close, even if the body raised.
        conn.close()
def initialize_database(db_path=DATABASE_PATH):
    """Create every table the application needs, if not already present.

    Schema: generation_logs (one row per generated image with all
    parameters), albums, categories, pictures (files on disk, linked to an
    album), and the picture_categories many-to-many link table.
    """
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS generation_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT,
            prompt TEXT,
            optimized_prompt TEXT,
            hf_lora TEXT,
            lora_scale REAL,
            aspect_ratio TEXT,
            guidance_scale REAL,
            output_quality INTEGER,
            prompt_strength REAL,
            num_inference_steps INTEGER,
            output_file TEXT,
            album_id INTEGER,
            category_id INTEGER
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS albums (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS categories (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS pictures (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT,
            file_path TEXT,
            file_name TEXT,
            album_id INTEGER,
            FOREIGN KEY (album_id) REFERENCES albums(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS picture_categories (
            picture_id INTEGER,
            category_id INTEGER,
            FOREIGN KEY (picture_id) REFERENCES pictures(id),
            FOREIGN KEY (category_id) REFERENCES categories(id),
            PRIMARY KEY (picture_id, category_id)
        )
        """,
    )
    with get_db_connection(db_path) as conn:
        cursor = conn.cursor()
        for ddl in ddl_statements:
            cursor.execute(ddl)
        conn.commit()
def log_generation(args, optimized_prompt, image_file):
    """Persist one generated image: a generation_logs row, a pictures row,
    and one picture_categories link per selected category.

    ``args`` is an argparse.Namespace carrying the generation parameters.
    Database errors are logged and swallowed deliberately so a failed log
    entry does not abort the generation flow.

    Fixes vs. original: the dead ``picture_id = cursor.lastrowid`` after the
    generation_logs INSERT (flagged as wrong in the original comments) is
    removed, and ``args.category_ids`` may now be None without crashing the
    link-insertion loop.
    """
    file_path, file_name = os.path.split(image_file)
    category_ids = args.category_ids or []  # tolerate None / missing
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO generation_logs (
                    timestamp, prompt, optimized_prompt, hf_lora, lora_scale, aspect_ratio, guidance_scale,
                    output_quality, prompt_strength, num_inference_steps, output_file, album_id, category_id
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                args.prompt,
                optimized_prompt,
                args.hf_lora,
                args.lora_scale,
                args.aspect_ratio,
                args.guidance_scale,
                args.output_quality,
                args.prompt_strength,
                args.num_inference_steps,
                image_file,
                args.album_id,
                # legacy single-category column: keep only the first selection
                category_ids[0] if category_ids else None,
            ))
            cursor.execute("""
                INSERT INTO pictures (
                    timestamp, file_path, file_name, album_id
                ) VALUES (?, ?, ?, ?)
            """, (
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                file_path,
                file_name,
                args.album_id,
            ))
            # lastrowid of the pictures INSERT — this is the id the link
            # table references.
            picture_id = cursor.lastrowid
            for category_id in category_ids:
                cursor.execute("""
                    INSERT INTO picture_categories (picture_id, category_id)
                    VALUES (?, ?)
                """, (picture_id, category_id))
            conn.commit()
    except sqlite3.Error as e:
        print(f"Error logging generation: {e}")  # TODO: switch to logger.error
def startup_event():
    """Create the SQLite schema once at application startup.

    NOTE(review): presumably this was registered via
    ``@app.on_event("startup")`` and the decorator was lost — as written it
    is never invoked automatically; confirm against the deployed version.
    """
    initialize_database()
def authenticate(credentials: HTTPBasicCredentials = Depends(security)):
    """FastAPI dependency validating HTTP Basic credentials.

    Compares against the CF_USER / CF_PASSWORD environment values using
    ``secrets.compare_digest`` (constant-time, as recommended by the FastAPI
    security docs) instead of ``!=``, which leaks timing information.

    Raises:
        HTTPException: 401 with a WWW-Authenticate challenge on mismatch.

    Returns:
        The authenticated username.
    """
    user_ok = secrets.compare_digest(credentials.username, USERNAME)
    pass_ok = secrets.compare_digest(credentials.password, PASSWORD)
    if not (user_ok and pass_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Ungültige Anmeldedaten",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username
| # Geschützte Route für archive.html | |
| def read_archive( | |
| request: Request, | |
| album: Optional[str] = Query(None), | |
| category: Optional[List[str]] = Query(None), | |
| search: Optional[str] = None, | |
| items_per_page: int = Query(30), | |
| page: int = Query(1), | |
| username: str = Depends(authenticate), | |
| ): | |
| album_id = int(album) if album and album.isdigit() else None | |
| category_ids = [int(cat) for cat in category] if category else [] | |
| offset = (page - 1) * items_per_page | |
| with get_db_connection() as conn: | |
| cursor = conn.cursor() | |
| query = """ | |
| SELECT gl.timestamp, gl.prompt, gl.optimized_prompt, gl.output_file, a.name as album, c.name as category | |
| FROM generation_logs gl | |
| LEFT JOIN albums a ON gl.album_id = a.id | |
| LEFT JOIN categories c ON gl.category_id = c.id | |
| WHERE 1=1 | |
| """ | |
| params = [] | |
| if album_id is not None: | |
| query += " AND gl.album_id = ?" | |
| params.append(album_id) | |
| if category_ids: | |
| query = """ | |
| SELECT gl.timestamp, gl.prompt, gl.optimized_prompt, gl.output_file, a.name as album, GROUP_CONCAT(c.name) as categories | |
| FROM generation_logs gl | |
| LEFT JOIN albums a ON gl.album_id = a.id | |
| LEFT JOIN picture_categories pc ON gl.id = pc.picture_id | |
| LEFT JOIN categories c ON pc.category_id = c.id | |
| WHERE 1=1 | |
| """ | |
| if album_id is not None: | |
| query += " AND gl.album_id = ?" | |
| params.append(album_id) | |
| query += " AND pc.category_id IN ({})".format(",".join("?" for _ in category_ids)) | |
| params.extend(category_ids) | |
| if search: | |
| query += " AND (gl.prompt LIKE ? OR gl.optimized_prompt LIKE ?)" | |
| params.append(f'%{search}%') | |
| params.append(f'%{search}%') | |
| query += " GROUP BY gl.id, gl.timestamp, gl.prompt, gl.optimized_prompt, gl.output_file, a.name ORDER BY gl.timestamp DESC LIMIT ? OFFSET ?" | |
| params.extend([items_per_page, offset]) | |
| cursor.execute(query, params) | |
| logs = cursor.fetchall() | |
| logs = [{ | |
| "timestamp": log[0], | |
| "prompt": log[1], | |
| "optimized_prompt": log[2], | |
| "output_file": log[3], | |
| "album": log[4], | |
| "category": log[5] | |
| } for log in logs] | |
| cursor.execute("SELECT id, name FROM albums") | |
| albums = cursor.fetchall() | |
| cursor.execute("SELECT id, name FROM categories") | |
| categories = cursor.fetchall() | |
| return templates.TemplateResponse("archive.html", { | |
| "request": request, | |
| "logs": logs, | |
| "albums": albums, | |
| "categories": categories, | |
| "selected_album": album, | |
| "selected_categories": category_ids, | |
| "search_query": search, | |
| "items_per_page": items_per_page, | |
| "page": page, | |
| "username": username, | |
| }) | |
# Public route
def read_root(request: Request):
    """Render the public landing page with all albums and categories."""
    with get_db_connection() as conn:
        cur = conn.cursor()
        albums = cur.execute("SELECT id, name FROM albums").fetchall()
        categories = cur.execute("SELECT id, name FROM categories").fetchall()
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "albums": albums, "categories": categories},
    )
def read_backend(request: Request):
    """Render the admin backend page with all albums and categories."""
    with get_db_connection() as conn:
        cur = conn.cursor()
        albums = cur.execute("SELECT id, name FROM albums").fetchall()
        categories = cur.execute("SELECT id, name FROM categories").fetchall()
    return templates.TemplateResponse(
        "backend.html",
        {"request": request, "albums": albums, "categories": categories},
    )
async def get_backend_stats():
    """Collect dashboard statistics.

    Returns total image/album/category counts, images per month, images per
    category, and disk usage of IMAGE_STORAGE_PATH in megabytes.

    Robustness fix: the original called ``os.listdir`` unconditionally and
    raised FileNotFoundError (→ HTTP 500) when the storage directory did not
    exist yet (e.g. before the first generation).
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        # Total number of pictures
        cursor.execute("SELECT COUNT(*) FROM pictures")
        total_images = cursor.fetchone()[0]
        # Total number of albums
        cursor.execute("SELECT COUNT(*) FROM albums")
        total_albums = cursor.fetchone()[0]
        # Total number of categories
        cursor.execute("SELECT COUNT(*) FROM categories")
        total_categories = cursor.fetchone()[0]
        # Images per month (derived from the pictures timestamps)
        cursor.execute("""
            SELECT strftime('%Y-%m', timestamp) as month, COUNT(*)
            FROM pictures
            GROUP BY month
            ORDER BY month
        """)
        monthly_stats = [{"month": row[0], "count": row[1]} for row in cursor.fetchall()]
        # Images per category
        cursor.execute("""
            SELECT c.name, COUNT(pc.picture_id)
            FROM categories c
            LEFT JOIN picture_categories pc ON c.id = pc.category_id
            GROUP BY c.name
        """)
        category_stats = [{"name": row[0], "count": row[1]} for row in cursor.fetchall()]

    # Disk usage of the flat image directory (subdirectories are skipped).
    total_size = 0
    if os.path.isdir(IMAGE_STORAGE_PATH):
        for filename in os.listdir(IMAGE_STORAGE_PATH):
            filepath = os.path.join(IMAGE_STORAGE_PATH, filename)
            if os.path.isfile(filepath):
                total_size += os.path.getsize(filepath)
    total_size_mb = total_size / (1024 * 1024)

    return {
        "total_images": total_images,
        "albums": {
            "total": total_albums
        },
        "categories": {
            "total": total_categories,
            "data": category_stats
        },
        "storage_usage_mb": total_size_mb,
        "monthly": monthly_stats
    }
# Album routes
async def get_albums():
    """Return every album as a list of {'id': ..., 'name': ...} dicts."""
    with get_db_connection() as conn:
        rows = conn.cursor().execute("SELECT id, name FROM albums").fetchall()
    return [{"id": row_id, "name": row_name} for row_id, row_name in rows]
async def create_album_route(name: str = Form(...), description: Optional[str] = Form(None)):
    """Insert a new album row and report its id.

    Note: ``description`` is accepted by the form but never stored — the
    albums table has no description column.
    """
    try:
        with get_db_connection() as conn:
            cur = conn.cursor()
            cur.execute("INSERT INTO albums (name) VALUES (?)", (name,))
            conn.commit()
            created_id = cur.lastrowid
        return {"message": "Album erstellt", "id": created_id, "name": name}
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Error creating album: {e}")
async def delete_album(album_id: int):
    """Delete an album and all rows that reference it.

    Removes, in dependency order: picture_categories links, pictures,
    generation_logs entries, and finally the album row itself. The image
    files on disk are NOT touched.
    """
    cleanup_statements = (
        "DELETE FROM picture_categories WHERE picture_id IN (SELECT id FROM pictures WHERE album_id = ?)",
        "DELETE FROM pictures WHERE album_id = ?",
        "DELETE FROM generation_logs WHERE album_id = ?",
        "DELETE FROM albums WHERE id = ?",
    )
    try:
        with get_db_connection() as conn:
            cur = conn.cursor()
            for statement in cleanup_statements:
                cur.execute(statement, (album_id,))
            conn.commit()
        return {"message": f"Album {album_id} und zugehörige Einträge gelöscht"}
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Error deleting album: {e}")
async def update_album(album_id: int, request: Request):
    """Rename an album from a JSON body ``{"name": ...}``.

    Robustness fix: the original did ``data["name"]`` directly, so a body
    without a ``name`` key raised KeyError and surfaced as an unhandled 500;
    now it returns a clean 400. A 404 is returned when no row was updated
    (the HTTPException passes through the sqlite3-only except clause).
    """
    data = await request.json()
    name = data.get("name")
    if not name:
        raise HTTPException(status_code=400, detail="Feld 'name' fehlt")
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("UPDATE albums SET name = ? WHERE id = ?", (name, album_id))
            conn.commit()
            if cursor.rowcount == 0:
                raise HTTPException(status_code=404, detail=f"Album {album_id} nicht gefunden")
            return {"message": f"Album {album_id} aktualisiert"}
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Error updating album: {e}")
# Category routes
async def get_categories():
    """Return every category as a list of {'id': ..., 'name': ...} dicts."""
    with get_db_connection() as conn:
        rows = conn.cursor().execute("SELECT id, name FROM categories").fetchall()
    return [{"id": row_id, "name": row_name} for row_id, row_name in rows]
async def create_category_route(name: str = Form(...)):
    """Insert a new category row and report its id."""
    try:
        with get_db_connection() as conn:
            cur = conn.cursor()
            cur.execute("INSERT INTO categories (name) VALUES (?)", (name,))
            conn.commit()
            created_id = cur.lastrowid
        return {"message": "Kategorie erstellt", "id": created_id, "name": name}
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Error creating category: {e}")
async def delete_category(category_id: int):
    """Delete a category and all its picture_categories links."""
    cleanup_statements = (
        "DELETE FROM picture_categories WHERE category_id = ?",
        "DELETE FROM categories WHERE id = ?",
    )
    try:
        with get_db_connection() as conn:
            cur = conn.cursor()
            for statement in cleanup_statements:
                cur.execute(statement, (category_id,))
            conn.commit()
        return {"message": f"Kategorie {category_id} und zugehörige Einträge gelöscht"}
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Error deleting category: {e}")
async def update_category(category_id: int, request: Request):
    """Rename a category from a JSON body ``{"name": ...}``.

    Robustness fix (kept consistent with update_album): a body without a
    ``name`` key previously raised KeyError (unhandled 500); now it returns
    a clean 400. A 404 is returned when no row was updated.
    """
    data = await request.json()
    name = data.get("name")
    if not name:
        raise HTTPException(status_code=400, detail="Feld 'name' fehlt")
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("UPDATE categories SET name = ? WHERE id = ?", (name, category_id))
            conn.commit()
            if cursor.rowcount == 0:
                raise HTTPException(status_code=404, detail=f"Kategorie {category_id} nicht gefunden")
            return {"message": f"Kategorie {category_id} aktualisiert"}
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Error updating category: {e}")
async def download_images(request: Request):
    """Bundle the selected images into an in-memory ZIP and return it.

    Expects a JSON body ``{"selectedImages": [filename, ...]}``.

    Fixes vs. original:
    - Deliberate 400/404 HTTPExceptions were swallowed by the generic
      ``except Exception`` and re-raised as 500; they now pass through.
    - User-supplied filenames are reduced to their basename before joining
      with IMAGE_STORAGE_PATH, blocking "../" path traversal (the storage
      directory is flat, so this does not reject legitimate names).
    """
    log = logging.getLogger(__name__)
    try:
        body = await request.json()
        log.info(f"Received request body: {body}")
        image_files = body.get("selectedImages", [])
        if not image_files:
            raise HTTPException(status_code=400, detail="Keine Bilder ausgewählt.")
        log.info(f"Processing image files: {image_files}")
        if not os.path.exists(IMAGE_STORAGE_PATH):
            log.error(f"Storage path not found: {IMAGE_STORAGE_PATH}")
            raise HTTPException(status_code=500, detail="Storage path not found")
        zip_buffer = BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for image_file in image_files:
                # basename() guards against path traversal (e.g. "../../x").
                safe_name = os.path.basename(image_file)
                image_path = os.path.join(IMAGE_STORAGE_PATH, safe_name)
                log.info(f"Processing file: {image_path}")
                if os.path.exists(image_path):
                    zip_file.write(image_path, arcname=safe_name)
                else:
                    log.error(f"File not found: {image_path}")
                    raise HTTPException(status_code=404, detail=f"Bild {image_file} nicht gefunden.")
        zip_buffer.seek(0)
        return Response(
            content=zip_buffer.getvalue(),
            media_type="application/zip",
            headers={
                "Content-Disposition": "attachment; filename=images.zip"
            }
        )
    except HTTPException:
        # Let deliberate 4xx/5xx responses through unchanged.
        raise
    except Exception as e:
        log.error(f"Error in download_images: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def download_single_image(request: Request):
    """Serve one image from IMAGE_STORAGE_PATH as a download attachment.

    Expects a JSON body ``{"filename": ...}``.

    Fixes vs. original:
    - The log/detail/header strings contained garbled "(unknown)"
      placeholders where ``{filename}`` belongs; the interpolation is
      restored.
    - Deliberate 400/404 HTTPExceptions now pass through instead of being
      converted to 500 by the generic except clause.
    - The filename is reduced to its basename, blocking "../" traversal.
    """
    log = logging.getLogger(__name__)
    try:
        data = await request.json()
        filename = data.get("filename")
        log.info(f"Requested file download: {filename}")
        if not filename:
            log.error("No filename provided")
            raise HTTPException(status_code=400, detail="Kein Dateiname angegeben")
        # basename() guards against path traversal.
        filename = os.path.basename(filename)
        file_path = os.path.join(IMAGE_STORAGE_PATH, filename)
        log.info(f"Full file path: {file_path}")
        if not os.path.exists(file_path):
            log.error(f"File not found: {file_path}")
            raise HTTPException(status_code=404, detail=f"Datei {filename} nicht gefunden")
        # Determine the MIME type from the file extension.
        file_extension = filename.lower().split('.')[-1]
        mime_types = {
            'png': 'image/png',
            'jpg': 'image/jpeg',
            'jpeg': 'image/jpeg',
            'gif': 'image/gif',
            'webp': 'image/webp'
        }
        media_type = mime_types.get(file_extension, 'application/octet-stream')
        log.info(f"Serving file with media type: {media_type}")
        return FileResponse(
            path=file_path,
            filename=filename,
            media_type=media_type,
            headers={
                "Content-Disposition": f"attachment; filename={filename}"
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        log.error(f"Error in download_single_image: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def websocket_endpoint(websocket: WebSocket):
    """Main generation websocket.

    Receives either a single prompt payload or ``{"prompts": [...]}``.
    For each payload: numeric fields are coerced, non-numeric album /
    category values are treated as NEW names and inserted into the DB,
    the prompt is optionally optimized via Mistral, and the images are
    generated and streamed back.

    NOTE(review): presumably registered via ``@app.websocket(...)`` — the
    decorator is not visible in this file; confirm.
    """
    await websocket.accept()
    try:
        data = await websocket.receive_json()
        # Single payload or batch; a lone payload is wrapped into a list.
        prompts = data.get("prompts", [data])
        for prompt_data in prompts:
            # Client sends strings; coerce to the numeric types the model expects.
            prompt_data["lora_scale"] = float(prompt_data["lora_scale"])
            prompt_data["guidance_scale"] = float(prompt_data["guidance_scale"])
            prompt_data["prompt_strength"] = float(prompt_data["prompt_strength"])
            prompt_data["num_inference_steps"] = int(prompt_data["num_inference_steps"])
            prompt_data["num_outputs"] = int(prompt_data["num_outputs"])
            prompt_data["output_quality"] = int(prompt_data["output_quality"])
            # Handle new album and category creation: a non-digit value is a
            # NEW name to insert; a digit string is an existing row id.
            album_name = prompt_data.get("album_id")
            category_names = prompt_data.get("category_ids", [])
            if album_name and not album_name.isdigit():
                with get_db_connection() as conn:
                    cursor = conn.cursor()
                    cursor.execute(
                        "INSERT INTO albums (name) VALUES (?)", (album_name,)
                    )
                    conn.commit()
                    prompt_data["album_id"] = cursor.lastrowid
            else:
                prompt_data["album_id"] = int(album_name) if album_name else None
            category_ids = []
            for category_name in category_names:
                if not category_name.isdigit():
                    with get_db_connection() as conn:
                        cursor = conn.cursor()
                        cursor.execute(
                            "INSERT INTO categories (name) VALUES (?)", (category_name,)
                        )
                        conn.commit()
                        category_ids.append(cursor.lastrowid)
                else:
                    # NOTE(review): an empty string is not a digit, so this
                    # branch's None case appears unreachable — confirm intent.
                    category_ids.append(int(category_name) if category_name else None)
            prompt_data["category_ids"] = category_ids
            # Namespace gives attribute access (args.prompt etc.) downstream.
            args = argparse.Namespace(**prompt_data)
            # await websocket.send_json({"message": "Optimiere Prompt..."})
            optimized_prompt = (
                optimize_prompt(args.prompt)
                if getattr(args, "agent", False)
                else args.prompt
            )
            await websocket.send_json({"optimized_prompt": optimized_prompt})
            if prompt_data.get("optimize_only"):
                continue  # client only wanted the optimized prompt text
            await generate_and_download_image(websocket, args, optimized_prompt)
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        await websocket.send_json({"message": str(e)})
        raise e
    finally:
        # NOTE(review): close() after a WebSocketDisconnect may itself raise
        # on some Starlette versions (closing an already-closed socket) —
        # verify against the deployed starlette version.
        await websocket.close()
async def fetch_image(item, index, args, filenames, semaphore, websocket, timestamp):
    """Download one generated image URL into DOWNLOAD_DIR and report
    progress over the websocket.

    Appends the public ``/flux-pics/...`` path to ``filenames`` on success.

    Fix: the blocking ``requests.get`` previously ran directly inside this
    coroutine, stalling the event loop (and the websocket) and serializing
    all downloads despite the semaphore. It now runs in a worker thread via
    ``asyncio.to_thread`` (Python >= 3.9).
    """
    async with semaphore:  # limit the number of concurrent downloads
        try:
            response = await asyncio.to_thread(
                requests.get, item, timeout=TIMEOUT_DURATION
            )
            if response.status_code == 200:
                filename = (
                    f"{DOWNLOAD_DIR}/image_{timestamp}_{index}.{args.output_format}"
                )
                with open(filename, "wb") as file:
                    file.write(response.content)
                filenames.append(
                    f"/flux-pics/image_{timestamp}_{index}.{args.output_format}"
                )
                # Progress relative to the number of requested outputs.
                progress = int((index + 1) / args.num_outputs * 100)
                await websocket.send_json({"progress": progress})
            else:
                await websocket.send_json(
                    {
                        "message": f"Fehler beim Herunterladen des Bildes {index + 1}: {response.status_code}"
                    }
                )
        except requests.exceptions.Timeout:
            await websocket.send_json(
                {"message": f"Timeout beim Herunterladen des Bildes {index + 1}"}
            )
async def generate_and_download_image(websocket: WebSocket, args, optimized_prompt):
    """Run the Replicate Flux-LoRA model, download all outputs concurrently,
    log each file via log_generation, and notify the client.

    Sends ``{"generated_files": [...]}`` on success or a German error
    message on failure; failures are also re-raised so the caller sees them.
    """
    try:
        input_data = {
            "prompt": optimized_prompt,
            "hf_lora": getattr(
                args, "hf_lora", None
            ),  # use getattr to safely access hf_lora
            "lora_scale": args.lora_scale,
            "num_outputs": args.num_outputs,
            "aspect_ratio": args.aspect_ratio,
            "output_format": args.output_format,
            "guidance_scale": args.guidance_scale,
            "output_quality": args.output_quality,
            "prompt_strength": args.prompt_strength,
            "num_inference_steps": args.num_inference_steps,
            "disable_safety_checker": False,
        }
        # await websocket.send_json({"message": "Generiere Bilder..."})
        # Debug: log the start of the replication process
        print(
            f"Starting replication process for {args.num_outputs} outputs with timeout {TIMEOUT_DURATION}"
        )
        # Blocking call to the pinned Replicate model version.
        # NOTE(review): this blocks the event loop for the duration of the
        # generation — presumably acceptable for this single-user app; confirm.
        output = replicate.run(
            "lucataco/flux-dev-lora:091495765fa5ef2725a175a57b276ec30dc9d39c22d30410f2ede68a3eab66b3",
            input=input_data,
            timeout=TIMEOUT_DURATION,
        )
        if not os.path.exists(DOWNLOAD_DIR):
            os.makedirs(DOWNLOAD_DIR)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filenames = []
        semaphore = Semaphore(3)  # limit concurrent downloads
        # One download task per returned image URL.
        tasks = [
            create_task(
                fetch_image(
                    item, index, args, filenames, semaphore, websocket, timestamp
                )
            )
            for index, item in enumerate(output)
        ]
        await gather(*tasks)
        # Record every successfully downloaded file in the database.
        for file in filenames:
            log_generation(args, optimized_prompt, file)
        await websocket.send_json(
            {"message": "Bilder erfolgreich generiert", "generated_files": filenames}
        )
    except requests.exceptions.Timeout:
        await websocket.send_json(
            {"message": "Fehler bei der Bildgenerierung: Timeout überschritten"}
        )
    except Exception as e:
        await websocket.send_json(
            {"message": f"Fehler bei der Bildgenerierung: {str(e)}"}
        )
        raise Exception(f"Fehler bei der Bildgenerierung: {str(e)}")
def optimize_prompt(prompt):
    """Ask the configured Mistral agent to rewrite *prompt* for Flux LoRA.

    Reads MISTRAL_API_KEY and MISTRAL_FLUX_AGENT from the environment and
    raises ValueError when either is missing.
    """
    key = os.environ.get("MISTRAL_API_KEY")
    agent = os.environ.get("MISTRAL_FLUX_AGENT")
    if not key or not agent:
        raise ValueError("MISTRAL_API_KEY oder MISTRAL_FLUX_AGENT nicht gesetzt")

    client = Mistral(api_key=key)
    chat_response = client.agents.complete(
        agent_id=agent,
        messages=[
            {
                "role": "user",
                "content": f"Optimiere folgenden Prompt für Flux Lora: {prompt}",
            }
        ],
    )
    return chat_response.choices[0].message.content
if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Beschreibung")
    parser.add_argument('--hf_lora', default=None, help='HF LoRA Model')
    args = parser.parse_args()
    # Pass arguments to the FastAPI application
    app.state.args = args
    # Run the Uvicorn server.
    # NOTE(review): "app:app" assumes this file is named app.py — confirm.
    # reload=True and log_level="trace" look like development settings;
    # consider disabling both for production.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
        log_level="trace"
    )