| import os
|
| from huggingface_hub import InferenceClient
|
| from dotenv import load_dotenv
|
| from PIL import Image, ImageFile, UnidentifiedImageError
|
| import io
|
| import sys
|
|
|
|
|
# Keep Pillow strict: truncated image data makes Image.load() raise instead of
# returning a partially-decoded image, so a cut-off download from the
# generation API fails at image.load() (where it can be retried) rather than
# being served as a broken picture.
ImageFile.LOAD_TRUNCATED_IMAGES = False
|
|
|
| from cora_vision import CoraVision
|
| from cora_memory import CoraMemory
|
|
|
class CoraEngine:
    """Text-to-image engine with layered fallbacks.

    Generation tries a primary Hugging Face model, then a secondary model,
    and finally a retrieval ("RAG") fallback that serves the closest stored
    image found through ``CoraVision`` text embeddings and ``CoraMemory``
    vector search. It can also describe images with an Ollama vision model.
    """

    def __init__(self):
        """Load configuration from the environment and set up components.

        Reads ``HF_API_TOKEN``/``HF_TOKEN``, ``OLLAMA_HOST`` and
        ``OLLAMA_VISION_MODEL`` after calling ``load_dotenv()``. The
        vision/memory components are optional. If either fails to
        initialise, both are set to ``None`` and the RAG fallback is disabled.
        """
        load_dotenv()

        self.HF_TOKEN = os.getenv("HF_API_TOKEN") or os.getenv("HF_TOKEN")
        self.OLLAMA_HOST = os.getenv("OLLAMA_HOST") or "http://localhost:11434"
        self.OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL", "llava")

        self.MODEL_ID = "black-forest-labs/FLUX.1-schnell"
        self.FALLBACK_MODEL_ID = "stabilityai/stable-diffusion-2-1"
        # Style suffix appended verbatim to every user prompt (it starts with
        # ", " so it reads as a continuation of the user's text).
        self.SYSTEM_PROMPT = ", historical social realism, ethnographic illustration, museum quality, natural window lighting, authentic period textures, oil on canvas, soot and wear, period accurate, sharp focus"
        # NOTE(review): not sent in any request built by this class. Presumably
        # it is used elsewhere or intended for the payload; confirm before relying on it.
        self.NEGATIVE_PROMPT = "fantasy, digital vibrancy, neon, plastic, 3d render, blur, low quality, jpeg artifacts, ugly, duplicate, mutilated, out of frame, extra fingers, mutated hands"

        try:
            self.vision = CoraVision()
            self.memory = CoraMemory()
        except Exception as exc:
            # Optional components: degrade gracefully, but report why. Do not
            # use a bare except, which would also trap KeyboardInterrupt.
            print("⚠️ Engine could not load Vision/Memory components. RAG Fallback disabled.")
            print(f"   Reason: {type(exc).__name__}: {exc}", file=sys.stderr)
            self.vision = None
            self.memory = None

        # NOTE(review): generate_from_text calls the HTTP API directly, so this
        # client is not used by the methods in this class. Confirm external use.
        if self.HF_TOKEN:
            self.client = InferenceClient(api_key=self.HF_TOKEN)
        else:
            self.client = None
            print("⚠️ Warning: HF_API_TOKEN or HF_TOKEN not found. Cloud image generation will fail.")

    def analyze_image_with_ollama(self, image_path, prompt="Describe this image in detail."):
        """Describe an image using the configured Ollama vision model.

        Sends the file's bytes (base64-encoded) together with ``prompt`` as a
        single non-streaming user message to ``{OLLAMA_HOST}/api/chat``.

        Args:
            image_path: Path of the image file to send.
            prompt: Instruction text for the model.

        Returns:
            The reply text, or ``None`` if any of the following happens: the
            file cannot be read, the request fails, Ollama answers with an
            HTTP error, or the reply has no message content.
        """
        try:
            import base64

            import requests

            with open(image_path, "rb") as f:
                img_str = base64.b64encode(f.read()).decode()

            # Strip a trailing slash so "http://host:11434/" does not become "//api/chat".
            url = f"{self.OLLAMA_HOST.rstrip('/')}/api/chat"
            payload = {
                "model": self.OLLAMA_VISION_MODEL,
                "messages": [
                    {
                        "role": "user",
                        "content": prompt,
                        "images": [img_str],
                    }
                ],
                "stream": False,
            }
            response = requests.post(url, json=payload, timeout=60)
            # Report HTTP errors (e.g. unknown model) instead of silently
            # returning None because the error body has no "message" key.
            response.raise_for_status()
            return response.json().get("message", {}).get("content")
        except Exception as e:
            print(f"Ollama Vision failed: {e}")
            return None

    def resize_image(self, image, max_size=1024):
        """Downscale ``image`` so that its longest side equals ``max_size``.

        The aspect ratio is preserved. Images already within the limit are
        returned unchanged (the same object), and ``None`` passes through.
        """
        if image is None:
            return None

        width, height = image.size
        longest = max(width, height)
        if longest <= max_size:
            return image

        ratio = max_size / longest
        # round() rather than int(): truncation can turn 1023.9999... into 1023.
        # max(1, ...) keeps very thin images from getting a 0-pixel side,
        # which Pillow rejects.
        new_width = max(1, round(width * ratio))
        new_height = max(1, round(height * ratio))
        return image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    def generate_from_text(self, user_prompt, use_fallback=True):
        """Generate an image from text, degrading through several fallbacks.

        Attempts are made in this order:
          1. ``self.MODEL_ID`` at 1024x1024 via the Hugging Face router.
          2. ``self.FALLBACK_MODEL_ID`` at 768x768, only if ``use_fallback``.
          3. The RAG fallback: the nearest stored image, if vision and memory
             are available.

        Args:
            user_prompt: Text description; ``self.SYSTEM_PROMPT`` is appended.
            use_fallback: Whether to try the secondary model (step 2).

        Returns:
            A fully loaded PIL image.

        Raises:
            ValueError: If no Hugging Face token is configured.
            RuntimeError: If every route fails. It is chained to the primary
                model's error.
        """
        # Fail fast on missing credentials, before importing or printing anything.
        if not self.HF_TOKEN:
            raise ValueError(
                "Authentication error: Missing HF_API_TOKEN or HF_TOKEN."
            )

        import time
        from io import BytesIO

        import requests

        final_prompt = f"{user_prompt}{self.SYSTEM_PROMPT}"
        print(f"Archiving (Text): '{user_prompt}'...")

        headers = {
            "Authorization": f"Bearer {self.HF_TOKEN}",
            "x-wait-for-model": "true",
        }

        class _NonRetryableError(ValueError):
            """API failure that repeating the same request cannot fix."""

        def try_model(model_id, width, height, max_attempts=2):
            """Request one image from ``model_id``; return it or raise.

            Network errors, undecodable bodies and other API errors are
            retried, up to ``max_attempts`` tries in total. A "model loading"
            or 503 response waits 5 s before retrying, and other failures
            wait 2 s. Auth/gated (401/403) and quota (402) errors are raised
            immediately.
            """
            url = f"https://router.huggingface.co/hf-inference/models/{model_id}"
            payload = {
                "inputs": final_prompt,
                "parameters": {
                    "width": width,
                    "height": height,
                },
            }

            for attempt in range(max_attempts):
                retries_left = attempt < max_attempts - 1
                try:
                    response = requests.post(url, headers=headers, json=payload, timeout=120)
                    if response.status_code == 200:
                        image = Image.open(BytesIO(response.content))
                        # Decode fully now, so a corrupt or truncated body
                        # fails here (and is retried) rather than at first use.
                        image.load()
                        print(f"✅ Received valid image from {model_id} ({image.format}, {image.size})")
                        return image

                    resp_text = response.text.lower()
                    print(f"⚠️ Model {model_id} returned API Error {response.status_code}: {resp_text}")

                    if response.status_code in (401, 403):
                        raise _NonRetryableError(f"Auth error or gated repo for {model_id}")
                    if response.status_code == 402:
                        raise _NonRetryableError(f"Inference Provider limits reached for {model_id}")

                    if ("loading" in resp_text or response.status_code == 503) and retries_left:
                        print("Model loading... retrying in 5 seconds.")
                        time.sleep(5)
                        continue
                    raise ValueError(f"API Error {response.status_code}: {response.text}")
                except _NonRetryableError:
                    # Retrying with the same token and model cannot succeed.
                    raise
                except Exception:
                    if retries_left:
                        time.sleep(2)
                        continue
                    raise

            # Only reachable when max_attempts < 1.
            raise ValueError(f"No request attempted for {model_id} (max_attempts={max_attempts})")

        try:
            return try_model(self.MODEL_ID, 1024, 1024)
        except Exception as e:
            primary_error = e

        print(f"⚠️ Primary Generation Error [{type(primary_error).__name__}]: {primary_error}", file=sys.stderr)

        if use_fallback:
            print(f"⚠️ Primary model {self.MODEL_ID} failed. Trying fallback {self.FALLBACK_MODEL_ID}...")
            try:
                return try_model(self.FALLBACK_MODEL_ID, 768, 768)
            except Exception as fe:
                print(f"❌ Fallback model also failed: {fe}")

        print(f"⚠️ Generation failed: {primary_error}. Attempting RAG Fallback...")

        rag_image = self._rag_fallback_image(user_prompt)
        if rag_image is not None:
            return rag_image

        raise RuntimeError(f"Generation failed: {primary_error}") from primary_error

    def _rag_fallback_image(self, user_prompt):
        """Return the stored image nearest to ``user_prompt``, or ``None``.

        This needs both ``self.vision`` and ``self.memory`` to be truthy. A
        lookup or file error is printed and treated as "no match".
        """
        vision = getattr(self, "vision", None)
        memory = getattr(self, "memory", None)
        if not (memory and vision):
            return None

        try:
            emb = vision.embed_text(user_prompt)
            results = memory.search_by_vector(emb, k=1)

            # Assumes a Chroma-style result: per-query lists under 'ids' and
            # 'metadatas', with metadata dicts carrying a 'path' key.
            # TODO confirm against CoraMemory.
            if results and results.get("ids") and results["ids"][0]:
                metadatas = results["metadatas"][0]
                if metadatas:
                    path = metadatas[0].get("path")
                    if path and os.path.exists(path):
                        print(f"✅ RAG Fallback successful! Serving: {path}")
                        image = Image.open(path)
                        # Decode eagerly, matching the API path, so a corrupt
                        # file is reported here and not at first use.
                        image.load()
                        return image
        except Exception as mem_e:
            print(f"RAG Fallback failed: {mem_e}")
        return None
|
|
|