File size: 8,234 Bytes
38ab39c 1c47eb5 38ab39c 001e827 38ab39c 001e827 38ab39c 15a9a47 38ab39c 15a9a47 38ab39c 15a9a47 38ab39c 15a9a47 001e827 15a9a47 1c47eb5 15a9a47 1c47eb5 15a9a47 38ab39c 15a9a47 38ab39c 15a9a47 38ab39c 15a9a47 38ab39c 15a9a47 38ab39c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 | import os
from huggingface_hub import InferenceClient
from dotenv import load_dotenv
from PIL import Image, ImageFile, UnidentifiedImageError
import io
import sys
# Do NOT load truncated images. We want to catch the error and retry or fallback.
ImageFile.LOAD_TRUNCATED_IMAGES = False
from cora_vision import CoraVision
from cora_memory import CoraMemory
class CoraEngine:
    """Image generation and analysis engine.

    Combines Hugging Face text-to-image inference (a primary model with a
    secondary fallback model), an Ollama vision model for describing images,
    and an optional visual-RAG fallback that serves the closest archived image
    when remote generation fails.
    """

    def __init__(self):
        """Load configuration from the environment and set up optional components."""
        # 1. Configuration & Setup
        load_dotenv()
        self.HF_TOKEN = os.getenv("HF_API_TOKEN") or os.getenv("HF_TOKEN")
        self.OLLAMA_HOST = os.getenv("OLLAMA_HOST") or "http://localhost:11434"
        self.OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL", "llava")

        # Migrated to FLUX.1-schnell (fast open-weights model).
        # Improved quality and speed over SDXL.
        self.MODEL_ID = "black-forest-labs/FLUX.1-schnell"
        self.FALLBACK_MODEL_ID = "stabilityai/stable-diffusion-2-1"
        # Style suffix appended verbatim to every user prompt (note the leading comma).
        self.SYSTEM_PROMPT = ", historical social realism, ethnographic illustration, museum quality, natural window lighting, authentic period textures, oil on canvas, soot and wear, period accurate, sharp focus"
        # NOTE(review): not referenced by generate_from_text in this class;
        # presumably read by external callers — verify before removing.
        self.NEGATIVE_PROMPT = "fantasy, digital vibrancy, neon, plastic, 3d render, blur, low quality, jpeg artifacts, ugly, duplicate, mutilated, out of frame, extra fingers, mutated hands"

        # Initialize RAG components. The RAG fallback needs both, so a failure
        # in either one disables both.
        try:
            self.vision = CoraVision()
            self.memory = CoraMemory()
        except Exception as e:
            print(f"⚠️ Engine could not load Vision/Memory components. RAG Fallback disabled. ({e})")
            self.vision = None
            self.memory = None

        # NOTE(review): self.client is not used by the methods of this class
        # (generation goes through raw HTTP); presumably used by callers.
        if self.HF_TOKEN:
            self.client = InferenceClient(api_key=self.HF_TOKEN)
        else:
            self.client = None
            print("⚠️ Warning: HF_API_TOKEN or HF_TOKEN not found. Cloud image generation will fail.")

    def analyze_image_with_ollama(self, image_path, prompt="Describe this image in detail."):
        """Describe an image using the configured Ollama vision model.

        Args:
            image_path: Path of the image file to send (read as raw bytes and
                base64-encoded).
            prompt: Instruction sent alongside the image.

        Returns:
            The model's reply text, or None on any failure. Failures (missing
            file, network error, HTTP error, malformed JSON) are logged, never
            raised.
        """
        try:
            import requests
            import base64

            with open(image_path, "rb") as f:
                img_str = base64.b64encode(f.read()).decode()

            # rstrip avoids a "//api/chat" URL when OLLAMA_HOST ends with "/".
            url = f"{self.OLLAMA_HOST.rstrip('/')}/api/chat"
            payload = {
                "model": self.OLLAMA_VISION_MODEL,
                "messages": [
                    {
                        "role": "user",
                        "content": prompt,
                        "images": [img_str],
                    }
                ],
                "stream": False,
            }
            response = requests.post(url, json=payload, timeout=60)
            # Surface HTTP errors (e.g. unknown model -> 404) in the log instead
            # of silently returning None from an {"error": ...} body.
            response.raise_for_status()
            return response.json().get("message", {}).get("content")
        except Exception as e:
            print(f"Ollama Vision failed: {e}")
            return None

    def resize_image(self, image, max_size=1024):
        """Downscale ``image`` so its longest side is ``max_size``, keeping aspect ratio.

        Args:
            image: A PIL image, or None.
            max_size: Target length in pixels for the longest side.

        Returns:
            None if ``image`` is None; ``image`` itself if it already fits;
            otherwise a new LANCZOS-resampled image.
        """
        if image is None:
            return None
        width, height = image.size
        if max(width, height) <= max_size:
            return image

        # Pin the long side to exactly max_size: computing it as
        # int(side * (max_size / side)) can float-round down to max_size - 1.
        # Clamp the short side to >= 1 px, since PIL rejects a 0-pixel dimension
        # (reachable for very thin images).
        if width >= height:
            new_width = max_size
            new_height = max(1, round(height * max_size / width))
        else:
            new_height = max_size
            new_width = max(1, round(width * max_size / height))
        return image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    def _try_model(self, model_id, prompt, width, height, max_attempts=2):
        """Request one image from ``model_id`` via the HF router, retrying transient failures.

        Args:
            model_id: Hugging Face model repo id.
            prompt: Full prompt text to send as ``inputs``.
            width: Requested image width in pixels.
            height: Requested image height in pixels.
            max_attempts: Total number of HTTP attempts.

        Returns:
            A fully decoded PIL image, or None only when ``max_attempts`` < 1.

        Raises:
            ValueError: immediately on HTTP 401 (auth / gated repo) or 402
                (provider limits); after the last attempt for any other
                non-200 status.
            Exception: the last network/decoding error if every attempt fails.
        """
        import requests
        import time
        from io import BytesIO

        # The old api-inference.huggingface.co is deprecated (410 Gone).
        # Use the new router.huggingface.co/hf-inference endpoint.
        url = f"https://router.huggingface.co/hf-inference/models/{model_id}"
        headers = {
            "Authorization": f"Bearer {self.HF_TOKEN}",
            "x-wait-for-model": "true",
        }
        payload = {
            "inputs": prompt,
            "parameters": {
                "width": width,
                "height": height,
            },
        }

        for attempt in range(max_attempts):
            is_last_attempt = attempt == max_attempts - 1
            try:
                response = requests.post(url, headers=headers, json=payload, timeout=120)
                if response.status_code == 200:
                    image = Image.open(BytesIO(response.content))
                    # Decode fully now so truncated/corrupt payloads fail here
                    # (module sets LOAD_TRUNCATED_IMAGES = False) and get retried.
                    image.load()
                    print(f"✅ Received valid image from {model_id} ({image.format}, {image.size})")
                    return image
            except Exception:
                # Network error or undecodable payload: brief back-off, then retry.
                if is_last_attempt:
                    raise
                time.sleep(2)
                continue

            # Non-200 response. This is handled outside the try above so that the
            # auth/quota errors below propagate immediately instead of being
            # caught and retried by the generic handler.
            resp_text = response.text.lower()
            print(f"⚠️ Model {model_id} returned API Error {response.status_code}: {resp_text}")
            if response.status_code == 401:
                raise ValueError(f"Auth error or gated repo for {model_id}")
            if response.status_code == 402:
                raise ValueError(f"Inference Provider limits reached for {model_id}")
            if not is_last_attempt:
                # A model that is still loading (503 or "loading" in the body) gets a longer wait.
                if "loading" in resp_text or response.status_code == 503:
                    print("Model loading... retrying in 5 seconds.")
                    time.sleep(5)
                else:
                    time.sleep(2)
                continue
            raise ValueError(f"API Error {response.status_code}: {response.text}")
        return None

    def _rag_fallback(self, user_prompt):
        """Return the closest archived image for ``user_prompt``, or None.

        Embeds the prompt with ``self.vision``, asks ``self.memory`` for the
        single nearest entry, and opens the file named by that entry's
        ``path`` metadata if it exists on disk. Any error is logged and
        treated as "no match".
        """
        if not (getattr(self, "memory", None) and getattr(self, "vision", None)):
            return None
        try:
            emb = self.vision.embed_text(user_prompt)
            results = self.memory.search_by_vector(emb, k=1)
            # presumably a Chroma-style result ({"ids": [[...]], "metadatas": [[{...}]]})
            # — verify against CoraMemory.search_by_vector.
            if not (results and results.get("ids") and results["ids"][0]):
                return None
            metadatas = results["metadatas"][0]
            path = metadatas[0].get("path") if metadatas else None
            if not (path and os.path.exists(path)):
                return None
            image = Image.open(path)
            # Decode eagerly so a corrupt archive file fails inside this handler
            # rather than later in the caller.
            image.load()
            print(f"✅ RAG Fallback successful! Serving: {path}")
            return image
        except Exception as mem_e:
            print(f"RAG Fallback failed: {mem_e}")
            return None

    def generate_from_text(self, user_prompt, use_fallback=True):
        """Text-to-image generation with secondary-model and visual-RAG fallbacks.

        Sources are tried in order:
        1. the primary model (``self.MODEL_ID``, 1024x1024);
        2. the fallback model (``self.FALLBACK_MODEL_ID``, 768x768), if ``use_fallback``;
        3. the visual-RAG archive lookup (always attempted when 1 and 2 fail).

        Args:
            user_prompt: Free-text description; ``self.SYSTEM_PROMPT`` is appended.
            use_fallback: Whether to try the secondary model after the primary fails.

        Returns:
            A PIL image from whichever source succeeds first.

        Raises:
            ValueError: if no Hugging Face token is configured.
            RuntimeError: if every source fails; chained to the primary model's error.
        """
        if not self.HF_TOKEN:
            raise ValueError(
                "Authentication error: Missing HF_API_TOKEN or HF_TOKEN."
            )

        final_prompt = f"{user_prompt}{self.SYSTEM_PROMPT}"
        print(f"Archiving (Text): '{user_prompt}'...")

        try:
            return self._try_model(self.MODEL_ID, final_prompt, 1024, 1024)
        except Exception as e:
            primary_error = e
            print(f"⚠️ Primary Generation Error [{type(e).__name__}]: {e}", file=sys.stderr)

        if use_fallback:
            print(f"⚠️ Primary model {self.MODEL_ID} failed. Trying fallback {self.FALLBACK_MODEL_ID}...")
            try:
                return self._try_model(self.FALLBACK_MODEL_ID, final_prompt, 768, 768)
            except Exception as fe:
                print(f"❌ Fallback model also failed: {fe}")

        print(f"⚠️ Generation failed: {primary_error}. Attempting RAG Fallback...")
        archived = self._rag_fallback(user_prompt)
        if archived is not None:
            return archived
        raise RuntimeError(f"Generation failed: {primary_error}") from primary_error
|