import os
from huggingface_hub import InferenceClient
from dotenv import load_dotenv
from PIL import Image, ImageFile, UnidentifiedImageError
import io
import sys

# Do NOT load truncated images. We want to catch the error and retry or fallback.
ImageFile.LOAD_TRUNCATED_IMAGES = False

from cora_vision import CoraVision
from cora_memory import CoraMemory


class CoraEngine:
    """Image engine combining cloud text-to-image generation, Ollama vision
    captioning, and a visual-RAG fallback that serves a stored image when
    cloud generation fails.
    """

    def __init__(self):
        # 1. Configuration & Setup
        load_dotenv()
        self.HF_TOKEN = os.getenv("HF_API_TOKEN") or os.getenv("HF_TOKEN")
        self.OLLAMA_HOST = os.getenv("OLLAMA_HOST") or "http://localhost:11434"
        self.OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL", "llava")

        # Migrated to FLUX.1-schnell (SOTA for fast open weights)
        # Improved quality and speed over SDXL.
        self.MODEL_ID = "black-forest-labs/FLUX.1-schnell"
        self.FALLBACK_MODEL_ID = "stabilityai/stable-diffusion-2-1"
        # Appended verbatim to every user prompt (hence the leading comma).
        self.SYSTEM_PROMPT = ", historical social realism, ethnographic illustration, museum quality, natural window lighting, authentic period textures, oil on canvas, soot and wear, period accurate, sharp focus"
        # NOTE(review): not included in any request payload in this file —
        # confirm whether other code reads it before relying on it.
        self.NEGATIVE_PROMPT = "fantasy, digital vibrancy, neon, plastic, 3d render, blur, low quality, jpeg artifacts, ugly, duplicate, mutilated, out of frame, extra fingers, mutated hands"

        # Initialize RAG Components. Failure is tolerated on purpose: the
        # engine still works, only the visual RAG fallback is disabled.
        try:
            self.vision = CoraVision()
            self.memory = CoraMemory()
        except Exception as e:
            print(f"⚠️ Engine could not load Vision/Memory components. RAG Fallback disabled. ({e})")
            self.vision = None
            self.memory = None

        if self.HF_TOKEN:
            self.client = InferenceClient(api_key=self.HF_TOKEN)
        else:
            self.client = None
            print("⚠️ Warning: HF_API_TOKEN or HF_TOKEN not found. Cloud image generation will fail.")

    def analyze_image_with_ollama(self, image_path, prompt="Describe this image in detail."):
        """Describe an image with an Ollama vision model.

        The image is base64-encoded and sent to the Ollama ``/api/chat``
        endpoint at ``self.OLLAMA_HOST`` using ``self.OLLAMA_VISION_MODEL``.

        Args:
            image_path: Path to the image file on disk.
            prompt: Instruction sent alongside the image.

        Returns:
            The model's reply text, or None if anything fails (the error is
            printed). This is a deliberate best-effort call and never raises.
        """
        try:
            import requests
            import base64

            with open(image_path, "rb") as f:
                img_str = base64.b64encode(f.read()).decode()

            # Tolerate an OLLAMA_HOST configured with a trailing slash.
            url = f"{self.OLLAMA_HOST.rstrip('/')}/api/chat"
            payload = {
                "model": self.OLLAMA_VISION_MODEL,
                "messages": [
                    {
                        "role": "user",
                        "content": prompt,
                        "images": [img_str],
                    }
                ],
                "stream": False,
            }
            response = requests.post(url, json=payload, timeout=60)
            # Report HTTP failures (e.g. unknown model) instead of silently
            # returning None from an error body.
            response.raise_for_status()
            message = response.json().get("message") or {}
            return message.get("content")
        except Exception as e:
            print(f"Ollama Vision failed: {e}")
            return None

    def resize_image(self, image, max_size=1024):
        """Downscale ``image`` so its longest side is ``max_size``, keeping aspect ratio.

        Images already within ``max_size`` (and None) are returned unchanged.
        """
        if image is None:
            return None
        width, height = image.size
        longest = max(width, height)
        # Check if resize is actually needed
        if longest <= max_size:
            return image
        # Integer arithmetic keeps the long side exactly max_size (a float
        # ratio can truncate it to max_size - 1); clamp to 1 so extreme
        # aspect ratios never yield a zero-sized side, which PIL rejects.
        new_width = max(1, width * max_size // longest)
        new_height = max(1, height * max_size // longest)
        return image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    def _request_image(self, model_id, prompt, headers, width, height, max_attempts=2):
        """POST one text-to-image request to the HF router, with limited retries.

        Up to ``max_attempts`` attempts are made. Network errors, undecodable or
        truncated image bodies, "model loading"/503 responses and other
        5xx/429 responses are retried. 401/402 and other 4xx responses are
        raised immediately, since retrying cannot fix them.

        Returns:
            A fully decoded PIL image (None only if ``max_attempts`` <= 0).

        Raises:
            ValueError: On a non-200 API response that is not retried, or
                that persists through the last attempt.
            requests.RequestException / OSError / ValueError: When the last
                attempt fails at the network or image-decoding level.
        """
        import requests
        import time
        from io import BytesIO

        # The old api-inference.huggingface.co is deprecated (410 Gone)
        # Use the new router.huggingface.co/hf-inference endpoint
        url = f"https://router.huggingface.co/hf-inference/models/{model_id}"
        payload = {
            "inputs": prompt,
            "parameters": {
                "width": width,
                "height": height,
            },
        }

        for attempt in range(max_attempts):
            last_attempt = attempt >= max_attempts - 1

            try:
                response = requests.post(url, headers=headers, json=payload, timeout=120)
            except requests.RequestException:
                if last_attempt:
                    raise
                time.sleep(2)
                continue

            if response.status_code == 200:
                try:
                    image = Image.open(BytesIO(response.content))
                    # Decode fully now so a truncated payload fails here
                    # (LOAD_TRUNCATED_IMAGES is False) and can be retried.
                    image.load()
                except (OSError, ValueError):
                    # UnidentifiedImageError and "image file is truncated"
                    # are both OSError subclasses/instances.
                    if last_attempt:
                        raise
                    time.sleep(2)
                    continue
                print(f"✅ Received valid image from {model_id} ({image.format}, {image.size})")
                return image

            status = response.status_code
            resp_text = response.text.lower()
            print(f"⚠️ Model {model_id} returned API Error {status}: {resp_text}")

            # Auth / billing failures will not resolve on retry: fail fast.
            if status == 401:
                raise ValueError(f"Auth error or gated repo for {model_id}")
            if status == 402:
                raise ValueError(f"Inference Provider limits reached for {model_id}")

            if not last_attempt:
                # If 503 or model loading string, retry
                if "loading" in resp_text or status == 503:
                    print("Model loading... retrying in 5 seconds.")
                    time.sleep(5)
                    continue
                # Transient server-side / rate-limit errors get one more try.
                if status == 429 or status >= 500:
                    time.sleep(2)
                    continue

            raise ValueError(f"API Error {status}: {response.text}")

        return None

    def _rag_fallback(self, user_prompt):
        """Return the closest stored image for ``user_prompt``, or None.

        The prompt is embedded with ``self.vision``, and the nearest entry is
        looked up in ``self.memory``. The image at that entry's ``path``
        metadata is returned if the file exists. This is best-effort: any
        error is printed and yields None.
        """
        if not (getattr(self, 'memory', None) and getattr(self, 'vision', None)):
            return None
        try:
            emb = self.vision.embed_text(user_prompt)
            results = self.memory.search_by_vector(emb, k=1)
            # The result appears to hold per-query nested lists (ids[0],
            # metadatas[0]). NOTE(review): confirm against
            # CoraMemory.search_by_vector.
            if not (results and results.get('ids') and results['ids'][0]):
                return None
            metadatas = (results.get('metadatas') or [None])[0]
            if not metadatas:
                return None
            path = (metadatas[0] or {}).get('path')
            if path and os.path.exists(path):
                image = Image.open(path)
                # Decode eagerly so a corrupt file is reported here. For
                # single-frame images this also releases the file handle.
                image.load()
                print(f"✅ RAG Fallback successful! Serving: {path}")
                return image
        except Exception as mem_e:
            print(f"RAG Fallback failed: {mem_e}")
        return None

    def generate_from_text(self, user_prompt, use_fallback=True):
        """
        Text-to-Image generation via direct Hugging Face API with secondary model fallback and RAG.

        The chain runs in this order:

        1. ``self.MODEL_ID`` is tried at 1024x1024.
        2. If that fails and ``use_fallback`` is true, ``self.FALLBACK_MODEL_ID``
           is tried at 768x768.
        3. Otherwise, the nearest stored image is served through the visual
           RAG fallback, when Vision/Memory are available.

        Args:
            user_prompt: The user's description; ``self.SYSTEM_PROMPT`` is appended.
            use_fallback: Whether to try the secondary cloud model. The RAG
                fallback is attempted regardless.

        Returns:
            A PIL image.

        Raises:
            ValueError: If no Hugging Face token is configured.
            RuntimeError: If every path fails. It is chained to the primary
                model's error.
        """
        if not self.HF_TOKEN:
            raise ValueError(
                "Authentication error: Missing HF_API_TOKEN or HF_TOKEN."
            )

        final_prompt = f"{user_prompt}{self.SYSTEM_PROMPT}"
        print(f"Archiving (Text): '{user_prompt}'...")
        headers = {
            "Authorization": f"Bearer {self.HF_TOKEN}",
            "x-wait-for-model": "true",
        }

        try:
            return self._request_image(self.MODEL_ID, final_prompt, headers, 1024, 1024)
        except Exception as e:
            # `e` is unbound after the except block; keep it for later use.
            primary_error = e
            print(f"⚠️ Primary Generation Error [{type(e).__name__}]: {e}", file=sys.stderr)

        if use_fallback:
            print(f"⚠️ Primary model {self.MODEL_ID} failed. Trying fallback {self.FALLBACK_MODEL_ID}...")
            try:
                return self._request_image(self.FALLBACK_MODEL_ID, final_prompt, headers, 768, 768)
            except Exception as fe:
                print(f"❌ Fallback model also failed: {fe}")

        print(f"⚠️ Generation failed: {primary_error}. Attempting RAG Fallback...")
        # Visual RAG Fallback
        image = self._rag_fallback(user_prompt)
        if image is not None:
            return image
        raise RuntimeError(f"Generation failed: {primary_error}") from primary_error