# cora / api.py
# Uploaded by tokgae with huggingface_hub (commit 1c47eb5, verified).
import io
import os
import uuid

from fastapi import FastAPI, Form, HTTPException, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from pydantic import BaseModel, Field

from cora_engine import CoraEngine
from cora_curator import CoraCurator
from cora_vision import CoraVision
from cora_memory import CoraMemory
# FastAPI application plus the service objects shared by every endpoint in
# this module. All of them are constructed once, at import time, in this order.
app = FastAPI(
    title="Cora API",
    description="Fake Historical Archive Generator",
)
engine = CoraEngine()    # text -> image generation (generate_from_text)
curator = CoraCurator()  # prompt refinement (refine_prompt)
vision = CoraVision()    # image/text embeddings and tag detection
memory = CoraMemory()    # vector DB of archived images (save / search)
class AgentPrompt(BaseModel):
    """JSON request body accepted by ``POST /agent/generate``."""

    # Free-text description of the image to generate.
    prompt: str
    # When true, the prompt is passed through the curator before generation.
    use_curator: bool = True
@app.get("/health")
def health_check():
    """Report whether each Cora component is available.

    Returns a JSON object with:
      - ``status``: "online", or "offline (engine)" when the engine has no client;
      - ``model``: the engine's ``MODEL_ID``;
      - ``curator``: the curator's ``MODEL_ID``, or "offline" without a client;
      - ``vision``: "online" when a CLIP model object is loaded, else "offline";
      - ``memory``: "online" when the memory client exists, else "offline".
    Availability is judged purely on the truthiness of those attributes.
    """
    return {
        "status": "online" if engine.client else "offline (engine)",
        "model": engine.MODEL_ID,
        "curator": curator.MODEL_ID if curator.client else "offline",
        "vision": "online" if vision.clip_model else "offline",
        "memory": "online" if memory.client else "offline",
    }
def archive_generation(image, prompt):
    """Helper to save image and metadata to Visual Memory.

    Intended to run as a FastAPI background task after a generation response
    has been sent, so failures are logged rather than raised.

    Steps:
      1. compute the image embedding and tags with ``vision`` (before touching
         the disk, so an analysis failure leaves no file behind);
      2. write the image as ``archive_images/<uuid4>.png``;
      3. record path, embedding, prompt and tags in ``memory``.
    If step 2 or 3 fails, the written file (if any) is removed, because a PNG
    that is not referenced by ``memory`` can never be found by search.

    Args:
        image: the generated image; must support ``.save(path)`` (PIL-like).
        prompt: the prompt that produced the image, stored as metadata.
    """
    filepath = None
    try:
        # Analyze (Vision) first: nothing is written if this fails.
        embedding = vision.embed_image(image)
        tags = vision.detect_tags(image)

        # Save to disk. The folder is created at import time, but recreate it
        # in case it was removed while the server is running.
        os.makedirs("archive_images", exist_ok=True)
        filepath = os.path.join("archive_images", f"{uuid.uuid4()}.png")
        image.save(filepath)

        # Save to Memory (Vector DB)
        memory.save(filepath, embedding, prompt, tags)
        print(f"✅ Background Archiving Complete: {filepath} with tags {tags}")
    except Exception as e:
        # Best-effort: a background task has no client to report errors to.
        print(f"❌ Background Archiving Failed: {e}")
        # Drop a (possibly partial) file that the memory index does not reference.
        if filepath is not None and os.path.exists(filepath):
            try:
                os.remove(filepath)
            except OSError as cleanup_error:
                print(f"⚠️ Could not remove orphaned file {filepath}: {cleanup_error}")
@app.post("/agent/generate")
async def agent_generate(request: AgentPrompt, background_tasks: BackgroundTasks):
    """
    Agent-friendly endpoint receiving JSON.
    Returns the raw PNG image.

    Steps:
      1. validate the prompt;
      2. optionally refine it with the curator, falling back to the original
         prompt if the curator raises;
      3. generate the image with the engine;
      4. schedule ``archive_generation`` as a background task;
      5. return the image encoded as PNG.

    The curator, engine and PNG-encoding calls are synchronous. Calling them
    directly inside this ``async`` handler would block the event loop, stalling
    every other request (including ``/health``) for the whole generation. They
    are therefore run in a worker thread via ``run_in_threadpool``.

    Raises:
        HTTPException(400): empty/whitespace-only prompt, or a ValueError
            during generation.
        HTTPException(500): a RuntimeError during generation (with a dedicated
            message when it mentions a timeout) or any unexpected error.
    """
    # Validate input before doing any work.
    if not request.prompt or not request.prompt.strip():
        raise HTTPException(
            status_code=400,
            detail="Prompt cannot be empty. Please provide a description."
        )
    try:
        # 1. Curate (Refine Prompt) -- optional and best-effort.
        final_prompt = request.prompt
        if request.use_curator:
            try:
                final_prompt = await run_in_threadpool(curator.refine_prompt, request.prompt)
            except Exception as curator_error:
                print(f"Curator failed: {curator_error}, using original prompt")
                # Fallback to original if curator fails
                final_prompt = request.prompt
        # 2. Generate (blocking call, kept off the event loop)
        result = await run_in_threadpool(engine.generate_from_text, final_prompt)
        # 3. Archive (Background Task): FastAPI runs it after the response is sent.
        background_tasks.add_task(archive_generation, result, final_prompt)
        # 4. Return as PNG (encoding is CPU-bound, so it also runs in a thread)
        img_byte_arr = io.BytesIO()
        await run_in_threadpool(result.save, img_byte_arr, format="PNG")
        return Response(content=img_byte_arr.getvalue(), media_type="image/png")
    except HTTPException:
        raise
    except ValueError as e:
        # User input errors
        raise HTTPException(
            status_code=400,
            detail=f"Invalid request: {str(e)}"
        ) from e
    except RuntimeError as e:
        # Server/API errors
        error_msg = str(e).lower()
        if "timeout" in error_msg or "took too long" in error_msg:
            # NOTE(review): 504 would describe a timeout more precisely; kept
            # at 500 so existing clients keep seeing the same status code.
            raise HTTPException(
                status_code=500,
                detail="Image generation timed out. Try a simpler prompt."
            ) from e
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        ) from e
    except Exception as e:
        print(f"Unexpected server error: {e}")
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred. Please try again."
        ) from e
@app.post("/v1/archive")
async def generate_archive(
    background_tasks: BackgroundTasks,
    prompt: str = Form(...)
):
    """
    Generates an 'archive' style image from text.

    Form-encoded variant used by the UI. The prompt is always passed through
    the curator; as in ``/agent/generate``, the raw prompt is used if the
    curator raises. The image is archived in the background and returned as PNG.

    The curator, engine and PNG-encoding calls are synchronous and run in a
    worker thread (``run_in_threadpool``) so they do not block the event loop.

    Raises:
        HTTPException(400): empty/whitespace-only prompt, or a ValueError
            during generation.
        HTTPException(500): a RuntimeError during generation, or any
            unexpected error (reported as "Internal Server Error").
    """
    # Validated outside the try: the broad handler below would otherwise turn
    # this 400 into a 500.
    if not prompt.strip():
        raise HTTPException(
            status_code=400,
            detail="Prompt cannot be empty. Please provide a description."
        )
    try:
        # 1. Curate (Auto-refine for UI), best-effort like /agent/generate.
        try:
            enhanced_prompt = await run_in_threadpool(curator.refine_prompt, prompt)
        except Exception as curator_error:
            print(f"Curator failed: {curator_error}, using original prompt")
            enhanced_prompt = prompt
        # 2. Generate (blocking call, kept off the event loop)
        result = await run_in_threadpool(engine.generate_from_text, enhanced_prompt)
        # 3. Archive (Background Task): runs after the response is sent.
        background_tasks.add_task(archive_generation, result, enhanced_prompt)
        # 4. Return as PNG
        img_byte_arr = io.BytesIO()
        await run_in_threadpool(result.save, img_byte_arr, format="PNG")
        return Response(content=img_byte_arr.getvalue(), media_type="image/png")
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except RuntimeError as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    except Exception as e:
        print(f"Server Error: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
class SearchQuery(BaseModel):
    """JSON request body accepted by ``POST /curator/search``."""

    # Free-text search query.
    query: str
    # Maximum number of results requested from the vector store. Values below
    # 1 are rejected during request validation (HTTP 422) instead of being
    # forwarded to the vector store and surfacing as a 500.
    limit: int = Field(10, ge=1)
# Cultural/temporal keywords detected in search queries. When any keyword of
# an entry occurs in the lower-cased query, ALL keywords of that entry are
# passed to the vector store as tag filters.
# NOTE(review): matching is a plain substring test ("rome" also matches
# "chrome"), and "enlightment" is misspelled, so a query containing
# "enlightenment" does not match that entry. The strings are kept unchanged
# because they may need to agree with tags already stored in memory.
_CULTURAL_MARKERS = {
    "roman": ["roman", "rome"],
    "greek": ["greek", "greece", "hellenic"],
    "egyptian": ["egypt", "egyptian"],
    "medieval": ["medieval", "middle ages"],
    "renaissance": ["renaissance"],
    "enlightment century": ["enlightment century"],
    "industrial revolution": ["industrial revolution"],
    "modern times": ["modern times", "20th century", "21st century"],
}

# Base URL used to build links to archived images. The default matches the
# development server started by this module's __main__ block. Set the
# CORA_PUBLIC_URL environment variable when the API is reachable elsewhere
# (e.g. a hosted deployment); a trailing slash is ignored.
_ARCHIVE_BASE_URL = os.environ.get("CORA_PUBLIC_URL", "http://localhost:8000").rstrip("/")


def _curator_search_sync(request):
    """Blocking part of ``curator_search``: embed, query memory, format results."""
    # 1. Embed query
    emb = vision.embed_text(request.query)
    # ``not emb`` raises "truth value is ambiguous" for array-like embeddings;
    # an explicit None/length check works for lists and arrays alike.
    if emb is None or len(emb) == 0:
        return {"results": []}

    # 2. Extract potential tags from query for filtering
    query_lower = request.query.lower()
    tag_hints = []
    for keywords in _CULTURAL_MARKERS.values():
        if any(keyword in query_lower for keyword in keywords):
            tag_hints.extend(keywords)

    # 3. Use hybrid search if we detected cultural markers, otherwise fall
    # back to pure semantic search.
    if tag_hints:
        results = memory.search_hybrid(emb, k=request.limit, tag_filter=tag_hints)
    else:
        results = memory.search_by_vector(emb, k=request.limit)

    # 4. Format result. Results hold one nested list per query (hence [0]).
    images = []
    if results['ids']:
        for _uid, metadata, distance in zip(
            results['ids'][0], results['metadatas'][0], results['distances'][0]
        ):
            # A stored entry may have no metadata at all; treat it as empty.
            metadata = metadata or {}
            path = metadata.get('path')
            # Skip entries whose image file no longer exists on disk.
            if not (path and os.path.exists(path)):
                continue
            images.append({
                # Served by the /archive_images static mount.
                "path": f"{_ARCHIVE_BASE_URL}/archive_images/{os.path.basename(path)}",
                "tags": metadata.get('tags'),
                "prompt": metadata.get('prompt'),
                # Raw distance reported by the vector store.
                "score": float(distance),
            })
    return {"results": images}


@app.post("/curator/search")
async def curator_search(request: SearchQuery):
    """
    Semantic search for the UI gallery with intelligent filtering.

    Embeds the query, runs a tag-filtered hybrid search when the query names
    a known culture/era (pure vector search otherwise), and returns matching
    images as ``{"results": [{"path", "tags", "prompt", "score"}, ...]}`` where
    ``path`` is a URL under ``/archive_images``.

    The embedding, vector search and file checks are synchronous, so they run
    in a worker thread instead of blocking the event loop.

    Raises:
        HTTPException(500): any error during embedding, search or formatting.
    """
    try:
        return await run_in_threadpool(_curator_search_sync, request)
    except Exception as e:
        print(f"Search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Mount static files so the UI can fetch archived images at
# /archive_images/<filename> (the URLs built by /curator/search).
from fastapi.staticfiles import StaticFiles
# StaticFiles refuses a missing directory, so create it first. exist_ok avoids
# the check-then-create race of os.path.exists + os.makedirs.
os.makedirs("archive_images", exist_ok=True)
app.mount("/archive_images", StaticFiles(directory="archive_images"), name="archive_images")
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on every
    # network interface, port 8000.
    from uvicorn import run as serve

    serve(app, port=8000, host="0.0.0.0")