import io
import os
import uuid

from fastapi import FastAPI, Form, HTTPException, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from pydantic import BaseModel

from cora_engine import CoraEngine
from cora_curator import CoraCurator
from cora_vision import CoraVision
from cora_memory import CoraMemory
|
|
|
# FastAPI application object; every route and the static mount below are
# registered on it.
app = FastAPI(title="Cora API", description="Fake Historical Archive Generator")

# Shared component instances, created once at import time and used by every
# request handler in this module. The construction order is kept unchanged
# (NOTE(review): unverified whether constructors load models or open clients
# and depend on each other — confirm before reordering).

# engine: text-to-image generation (generate_from_text); exposes client and MODEL_ID.
engine = CoraEngine()

# curator: prompt refinement (refine_prompt); exposes client and MODEL_ID.
curator = CoraCurator()

# vision: image/text embeddings and tag detection (embed_image, embed_text,
# detect_tags); exposes clip_model.
vision = CoraVision()

# memory: storage and search backend for archived images (save, search_hybrid,
# search_by_vector); exposes client.
memory = CoraMemory()
|
|
|
class AgentPrompt(BaseModel):
    """JSON request body for ``POST /agent/generate``."""

    # Free-text description of the image to generate; must not be blank.
    prompt: str
    # When True (the default), the prompt is refined by the curator first.
    use_curator: bool = True
|
|
|
@app.get("/health")
def health_check():
    """Report which backend components are currently available.

    Returns:
        dict with keys, in order:
            "status":  "online", or "offline (engine)" when the engine has
                       no client.
            "model":   the engine's MODEL_ID.
            "curator": the curator's MODEL_ID, or "offline" without a client.
            "vision":  "online" if the CLIP model is loaded, else "offline".
            "memory":  "online" if the memory client exists, else "offline".
    """
    engine_model = engine.MODEL_ID
    engine_state = "offline (engine)" if not engine.client else "online"

    report = {"status": engine_state, "model": engine_model}
    report["curator"] = curator.MODEL_ID if curator.client else "offline"
    report["vision"] = "online" if vision.clip_model else "offline"
    report["memory"] = "online" if memory.client else "offline"
    return report
|
|
|
def archive_generation(image, prompt, archive_dir="archive_images"):
    """Save a generated image to disk and index it in visual memory.

    Used as a background task by the generation endpoints. Exceptions are
    caught and reported with ``print`` rather than re-raised.

    Args:
        image: The generated image; must provide ``save(path)`` (the endpoints
            pass the object returned by ``engine.generate_from_text``).
        prompt: The prompt that produced the image, stored alongside it.
        archive_dir: Directory the PNG is written to. Defaults to
            ``"archive_images"``, the directory served at ``/archive_images``.
    """
    filepath = None
    indexed = False
    try:
        filepath = os.path.join(archive_dir, f"{uuid.uuid4()}.png")
        image.save(filepath)

        embedding = vision.embed_image(image)
        tags = vision.detect_tags(image)
        memory.save(filepath, embedding, prompt, tags)
        indexed = True

        print(f"✅ Background Archiving Complete: {filepath} with tags {tags}")
    except Exception as e:
        print(f"❌ Background Archiving Failed: {e}")
        # Search only returns images recorded in memory, so a file that was
        # written but never indexed would sit on disk forever. Remove it.
        if not indexed and filepath is not None and os.path.exists(filepath):
            try:
                os.remove(filepath)
            except OSError as cleanup_error:
                print(f"❌ Could not remove orphaned archive file {filepath}: {cleanup_error}")
|
|
|
@app.post("/agent/generate")
async def agent_generate(request: AgentPrompt, background_tasks: BackgroundTasks):
    """Agent-friendly endpoint: generate an image from a JSON prompt.

    The prompt is optionally refined by the curator. If the curator fails,
    the original prompt is used instead. The prompt is then rendered by the
    engine, and the image is queued for archiving as a background task.

    Returns:
        The raw PNG image (``image/png``).

    Raises:
        HTTPException(400): blank prompt, or the pipeline raised ValueError.
        HTTPException(500): generation raised RuntimeError (timeout or other
            failure), or any unexpected error.
    """
    try:
        if not request.prompt or not request.prompt.strip():
            raise HTTPException(
                status_code=400,
                detail="Prompt cannot be empty. Please provide a description."
            )

        final_prompt = request.prompt
        if request.use_curator:
            try:
                # curator/engine calls are synchronous; running them in the
                # worker thread pool keeps the event loop free to serve other
                # requests while this one waits.
                final_prompt = await run_in_threadpool(curator.refine_prompt, request.prompt)
            except Exception as curator_error:
                # Best effort: keep final_prompt as the user's original prompt.
                print(f"Curator failed: {curator_error}, using original prompt")

        result = await run_in_threadpool(engine.generate_from_text, final_prompt)

        # Archiving (disk write, embedding, indexing) runs after the response.
        background_tasks.add_task(archive_generation, result, final_prompt)

        img_byte_arr = io.BytesIO()
        result.save(img_byte_arr, format='PNG')
        return Response(content=img_byte_arr.getvalue(), media_type="image/png")

    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid request: {str(e)}"
        ) from e
    except RuntimeError as e:
        error_msg = str(e).lower()
        if "timeout" in error_msg or "took too long" in error_msg:
            raise HTTPException(
                status_code=500,
                detail="Image generation timed out. Try a simpler prompt."
            ) from e
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        ) from e
    except Exception as e:
        print(f"Unexpected server error: {e}")
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred. Please try again."
        ) from e
|
|
|
@app.post("/v1/archive")
async def generate_archive(
    background_tasks: BackgroundTasks,
    prompt: str = Form(...)
):
    """Generate an 'archive' style image from a form-encoded prompt.

    The prompt is always refined by the curator before being rendered by the
    engine. Unlike ``/agent/generate``, a curator failure is not recovered
    from. The image is queued for archiving as a background task.

    Returns:
        The raw PNG image (``image/png``).

    Raises:
        HTTPException(400): blank prompt, or the pipeline raised ValueError.
        HTTPException(500): RuntimeError (message passed through) or any
            unexpected error.
    """
    # Same blank-prompt rule as /agent/generate. Validated before the try so
    # the generic handlers below cannot turn it into a 500.
    if not prompt.strip():
        raise HTTPException(
            status_code=400,
            detail="Prompt cannot be empty. Please provide a description."
        )

    try:
        # Synchronous calls run in the thread pool so they do not block the
        # event loop for other requests.
        enhanced_prompt = await run_in_threadpool(curator.refine_prompt, prompt)
        result = await run_in_threadpool(engine.generate_from_text, enhanced_prompt)

        # Archiving runs after the response has been sent.
        background_tasks.add_task(archive_generation, result, enhanced_prompt)

        img_byte_arr = io.BytesIO()
        result.save(img_byte_arr, format='PNG')
        return Response(content=img_byte_arr.getvalue(), media_type="image/png")

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except RuntimeError as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    except Exception as e:
        print(f"Server Error: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
|
|
|
class SearchQuery(BaseModel):
    """JSON request body for ``POST /curator/search``."""

    # Free-text query: embedded for semantic search and also scanned for
    # era/culture keywords that narrow the search by tag.
    query: str
    # Number of results requested from memory (forwarded as ``k``).
    limit: int = 10
|
|
|
# Era/culture keyword groups used to turn a free-text query into tag filters.
# When any keyword of a group occurs in the (lower-cased) query, every keyword
# of that group is used as a tag hint. Built once at import, not per request.
_CULTURAL_MARKERS = {
    "roman": ["roman", "rome"],
    "greek": ["greek", "greece", "hellenic"],
    "egyptian": ["egypt", "egyptian"],
    "medieval": ["medieval", "middle ages"],
    "renaissance": ["renaissance"],
    # "enlightment century" (sic) is kept so queries/tags using the original
    # spelling still match; "enlightenment" catches the correct spelling.
    "enlightenment": ["enlightenment", "enlightment century"],
    "industrial revolution": ["industrial revolution"],
    "modern times": ["modern times", "20th century", "21st century"],
}


def _tag_hints_for(query):
    """Return the keywords of every era/culture group mentioned in *query*."""
    query_lower = query.lower()
    hints = []
    for keywords in _CULTURAL_MARKERS.values():
        if any(kw in query_lower for kw in keywords):
            hints.extend(keywords)
    return hints


def _format_search_results(results):
    """Convert a memory search result into gallery entries for existing files."""
    images = []
    if not results['ids']:
        return images

    # Only the first query's result lists are used (index [0]).
    ids = results['ids'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]

    for _uid, metadata, distance in zip(ids, metadatas, distances):
        metadata = metadata or {}  # tolerate entries stored without metadata
        path = metadata.get('path')
        if not path or not os.path.exists(path):
            continue  # skip index entries whose image file is gone

        filename = os.path.basename(path)
        images.append({
            # NOTE(review): host/port are hard-coded; this URL only resolves
            # when the API is served at localhost:8000.
            "path": f"http://localhost:8000/archive_images/{filename}",
            "tags": metadata.get('tags'),
            "prompt": metadata.get('prompt'),
            "score": float(distance),
        })
    return images


@app.post("/curator/search")
async def curator_search(request: SearchQuery):
    """
    Semantic search for the UI gallery with intelligent filtering.

    The query is embedded with the vision model. If it mentions a known
    era/culture, a hybrid (vector + tag) search is used; otherwise a plain
    vector search is used. Entries whose image file no longer exists are
    dropped.

    Returns:
        {"results": [{"path", "tags", "prompt", "score"}, ...]}, where
        "score" is the distance reported by memory.

    Raises:
        HTTPException(400): ``limit`` is not a positive integer.
        HTTPException(500): any failure while embedding or searching.
    """
    # Validated before the try so the generic handler cannot turn it into a 500.
    if request.limit < 1:
        raise HTTPException(status_code=400, detail="limit must be a positive integer")

    try:
        # Embedding and vector search are synchronous; run them in the thread
        # pool so the event loop keeps serving other requests.
        emb = await run_in_threadpool(vision.embed_text, request.query)
        # Explicit emptiness check: `not emb` is ambiguous for array-like
        # embeddings (e.g. NumPy arrays raise on truth testing).
        if emb is None or len(emb) == 0:
            return {"results": []}

        tag_hints = _tag_hints_for(request.query)
        if tag_hints:
            results = await run_in_threadpool(
                memory.search_hybrid, emb, k=request.limit, tag_filter=tag_hints
            )
        else:
            results = await run_in_threadpool(
                memory.search_by_vector, emb, k=request.limit
            )

        return {"results": _format_search_results(results)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
from fastapi.staticfiles import StaticFiles

# Serve archived PNGs at /archive_images so the URLs returned by
# /curator/search resolve. exist_ok=True replaces the previous
# exists()-then-makedirs() pair, which raced if another process created the
# directory in between.
os.makedirs("archive_images", exist_ok=True)
app.mount("/archive_images", StaticFiles(directory="archive_images"), name="archive_images")
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|
|
|