cora / cora_engine.py
tokgae's picture
Upload folder using huggingface_hub
001e827 verified
import os
from huggingface_hub import InferenceClient
from dotenv import load_dotenv
from PIL import Image, ImageFile, UnidentifiedImageError
import io
import sys
# Do NOT load truncated images. We want to catch the error and retry or fallback.
ImageFile.LOAD_TRUNCATED_IMAGES = False
from cora_vision import CoraVision
from cora_memory import CoraMemory
class CoraEngine:
def __init__(self):
    """Load configuration from the environment and build the engine's clients.

    Settings are read after ``load_dotenv()`` has run:

    * ``HF_API_TOKEN`` (or ``HF_TOKEN``) -> ``self.HF_TOKEN``
    * ``OLLAMA_HOST`` -> ``self.OLLAMA_HOST`` (defaults to the local Ollama port)
    * ``OLLAMA_VISION_MODEL`` -> ``self.OLLAMA_VISION_MODEL`` (defaults to ``"llava"``)

    ``self.vision`` / ``self.memory`` are set to ``None`` when either RAG
    component fails to construct, and ``self.client`` is ``None`` when no
    Hugging Face token is available; in both cases a warning is printed
    instead of raising.
    """
    # 1. Configuration & Setup
    load_dotenv()
    self.HF_TOKEN = os.getenv("HF_API_TOKEN") or os.getenv("HF_TOKEN")
    self.OLLAMA_HOST = os.getenv("OLLAMA_HOST") or "http://localhost:11434"
    self.OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL", "llava")

    # Migrated to FLUX.1-schnell (SOTA for fast open weights)
    # Improved quality and speed over SDXL.
    self.MODEL_ID = "black-forest-labs/FLUX.1-schnell"
    self.FALLBACK_MODEL_ID = "stabilityai/stable-diffusion-2-1"

    # Style suffix; it starts with ", " because it is appended directly
    # to the user's prompt text.
    self.SYSTEM_PROMPT = ", historical social realism, ethnographic illustration, museum quality, natural window lighting, authentic period textures, oil on canvas, soot and wear, period accurate, sharp focus"
    # NOTE(review): not referenced by the methods visible in this part of
    # the file - confirm whether another code path consumes it.
    self.NEGATIVE_PROMPT = "fantasy, digital vibrancy, neon, plastic, 3d render, blur, low quality, jpeg artifacts, ugly, duplicate, mutilated, out of frame, extra fingers, mutated hands"

    # Initialize RAG Components
    try:
        self.vision = CoraVision()
        self.memory = CoraMemory()
    except Exception as exc:
        # `except Exception` (not a bare `except:`) so that Ctrl-C and
        # SystemExit still propagate; the cause is reported rather than
        # silently discarded.
        print("⚠️ Engine could not load Vision/Memory components. RAG Fallback disabled.")
        print(f"   Cause: {type(exc).__name__}: {exc}", file=sys.stderr)
        # Disable both together: the RAG fallback needs the pair.
        self.vision = None
        self.memory = None

    if self.HF_TOKEN:
        self.client = InferenceClient(api_key=self.HF_TOKEN)
    else:
        self.client = None
        print("⚠️ Warning: HF_API_TOKEN or HF_TOKEN not found. Cloud image generation will fail.")
def analyze_image_with_ollama(self, image_path, prompt="Describe this image in detail."):
    """Uses Ollama Vision to describe an image.

    The file at ``image_path`` is read, base64-encoded and sent together
    with ``prompt`` as a single non-streaming user message to
    ``{OLLAMA_HOST}/api/chat`` using ``self.OLLAMA_VISION_MODEL``.

    Args:
        image_path: Path of the image file to read in binary mode.
        prompt: Instruction text sent alongside the image.

    Returns:
        The ``"content"`` of the reply's ``"message"`` object, or ``None``
        when anything fails (missing file, network error, non-2xx HTTP
        status, malformed JSON). This method is best-effort and never
        raises; failures are printed.
    """
    try:
        # Imported lazily so the engine can be used without these modules
        # unless vision analysis is actually requested.
        import requests
        import base64

        with open(image_path, "rb") as f:
            img_str = base64.b64encode(f.read()).decode()

        # Tolerate a trailing slash in the configured host so the URL does
        # not end up as ".../​/api/chat".
        url = f"{self.OLLAMA_HOST.rstrip('/')}/api/chat"
        payload = {
            "model": self.OLLAMA_VISION_MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": prompt,
                    "images": [img_str],
                }
            ],
            "stream": False,
        }
        response = requests.post(url, json=payload, timeout=60)

        # An error reply carries no "message" key, which previously made
        # this method return None without any diagnostic. Report it.
        if not response.ok:
            print(f"Ollama Vision failed: HTTP {response.status_code}: {response.text}")
            return None

        return response.json().get("message", {}).get("content")
    except Exception as e:
        print(f"Ollama Vision failed: {e}")
        return None
def resize_image(self, image, max_size=1024):
    """Resizes image to ensure largest side is max_size, maintaining aspect ratio.

    Args:
        image: An object exposing ``.size`` as ``(width, height)`` and a
            ``.resize((w, h), resample)`` method (a PIL image), or ``None``.
        max_size: Upper bound, in pixels, for the longer side.

    Returns:
        ``None`` when ``image`` is ``None``; the *same* image object when
        it already fits within ``max_size``; otherwise a new image scaled
        down with LANCZOS resampling.

    Raises:
        ValueError: If the image needs shrinking but ``max_size`` is < 1.
    """
    if image is None:
        return None

    width, height = image.size
    longest_side = max(width, height)

    # Check if resize is actually needed
    if longest_side <= max_size:
        return image

    if max_size < 1:
        raise ValueError(f"max_size must be at least 1, got {max_size}")

    ratio = max_size / longest_side
    # Truncation can produce 0 for very elongated images (e.g. 5000x2),
    # and resizing to a zero-length side is rejected, so clamp the short
    # side to one pixel.
    new_width = max(1, int(width * ratio))
    new_height = max(1, int(height * ratio))
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
def generate_from_text(self, user_prompt, use_fallback=True):
    """
    Text-to-Image generation via direct Hugging Face API with secondary model fallback and RAG.

    The prompt sent to the API is ``user_prompt`` followed by
    ``self.SYSTEM_PROMPT``. Sources are tried in this order, and the first
    one that yields an image wins:

    1. ``self.MODEL_ID`` at 1024x1024.
    2. ``self.FALLBACK_MODEL_ID`` at 768x768 (only if ``use_fallback``).
    3. Visual RAG: embed ``user_prompt`` with ``self.vision``, look up the
       nearest entry in ``self.memory`` and open the file at its
       ``'path'`` metadata, if that file exists.

    Args:
        user_prompt: The user's text description.
        use_fallback: Whether to try the secondary model after the
            primary one fails.

    Returns:
        A PIL image, either decoded from the API response or opened from
        the path found by the RAG lookup.

    Raises:
        ValueError: If no Hugging Face token is configured.
        RuntimeError: If every source fails; the primary model's error is
            attached as ``__cause__``.
    """
    # Validate credentials first so a misconfigured engine fails fast,
    # before any optional import or network traffic.
    if not self.HF_TOKEN:
        raise ValueError(
            "Authentication error: Missing HF_API_TOKEN or HF_TOKEN."
        )

    import requests
    import time
    from io import BytesIO

    final_prompt = f"{user_prompt}{self.SYSTEM_PROMPT}"
    print(f"Archiving (Text): '{user_prompt}'...")

    headers = {
        "Authorization": f"Bearer {self.HF_TOKEN}",
        "x-wait-for-model": "true"
    }

    def try_model(model_id, width, height, max_attempts=2):
        """Request one image from ``model_id``, retrying transient failures."""
        # The old api-inference.huggingface.co is deprecated (410 Gone)
        # Use the new router.huggingface.co/hf-inference endpoint
        url = f"https://router.huggingface.co/hf-inference/models/{model_id}"
        payload = {
            "inputs": final_prompt,
            "parameters": {
                "width": width,
                "height": height
            }
        }

        for attempt in range(max_attempts):
            can_retry = attempt < max_attempts - 1

            # Only transport and image-decoding problems are handled by
            # this try; HTTP status handling lives below it so that the
            # non-retryable errors raised there are not caught and retried.
            try:
                response = requests.post(url, headers=headers, json=payload, timeout=120)
                if response.status_code == 200:
                    image = Image.open(BytesIO(response.content))
                    # Force a full decode now so a truncated or corrupt
                    # payload fails here (and is retried) instead of later.
                    image.load()
                    print(f"✅ Received valid image from {model_id} ({image.format}, {image.size})")
                    return image
            except Exception:
                if can_retry:
                    time.sleep(2)
                    continue
                raise

            resp_text = response.text.lower()
            print(f"⚠️ Model {model_id} returned API Error {response.status_code}: {resp_text}")

            # Return real error if it's 402/401: retrying cannot fix
            # missing access or an exhausted quota.
            if response.status_code == 401:
                raise ValueError(f"Auth error or gated repo for {model_id}")
            if response.status_code == 402:
                raise ValueError(f"Inference Provider limits reached for {model_id}")

            if can_retry:
                # If 503 or model loading string, give the model longer to
                # warm up; any other API error gets the short back-off.
                if "loading" in resp_text or response.status_code == 503:
                    print("Model loading... retrying in 5 seconds.")
                    time.sleep(5)
                else:
                    time.sleep(2)
                continue

            raise ValueError(f"API Error {response.status_code}: {response.text}")

        # Only reachable when max_attempts < 1.
        return None

    last_error = None
    try:
        return try_model(self.MODEL_ID, 1024, 1024)
    except Exception as e:
        # Keep a reference: the name bound by `except ... as` is cleared
        # when the handler ends, and the error is needed further down.
        last_error = e
        print(f"⚠️ Primary Generation Error [{type(e).__name__}]: {e}", file=sys.stderr)

    if use_fallback:
        print(f"⚠️ Primary model {self.MODEL_ID} failed. Trying fallback {self.FALLBACK_MODEL_ID}...")
        try:
            return try_model(self.FALLBACK_MODEL_ID, 768, 768)
        except Exception as fe:
            print(f"❌ Fallback model also failed: {fe}")

    print(f"⚠️ Generation failed: {last_error}. Attempting RAG Fallback...")

    # Visual RAG Fallback
    if getattr(self, 'memory', None) and getattr(self, 'vision', None):
        try:
            emb = self.vision.embed_text(user_prompt)
            results = self.memory.search_by_vector(emb, k=1)
            # NOTE(review): assumes `results` is a dict whose 'ids' and
            # 'metadatas' values are lists of per-query lists - confirm
            # against CoraMemory.search_by_vector.
            if results and results.get('ids') and results['ids'][0]:
                metadatas = results['metadatas'][0]
                if metadatas:
                    path = metadatas[0].get('path')
                    if path and os.path.exists(path):
                        print(f"✅ RAG Fallback successful! Serving: {path}")
                        return Image.open(path)
        except Exception as mem_e:
            print(f"RAG Fallback failed: {mem_e}")

    raise RuntimeError(f"Generation failed: {last_error}") from last_error