# Source: cognito-server/app.py (commit 8ca1a6a, uploaded by prelington)
# app.py
import asyncio
import os
from typing import List, Optional

import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
load_dotenv()
HF_API_TOKEN = os.getenv("HF_API_TOKEN") # Put your Hugging Face token here if using HF Inference API
MODEL_ID = os.getenv("HF_MODEL_ID", "gpt2") # default, replace with your Cognito model id
RUN_MODE = os.getenv("RUN_MODE", "inference_api") # or "local"
# If RUN_MODE == "local", we will lazily import transformers and create a pipeline
local_pipeline = None
app = FastAPI(title="Cognito - Chat API")
class ChatRequest(BaseModel):
session_id: Optional[str] = None
messages: List[dict] # [{"role":"user","content":"..."} , ...]
fetch_url: Optional[str] = None # optional: server will fetch a URL (only if connector authorized)
fetch_auth: Optional[dict] = None # optional auth info for connector (see docs)
def call_hf_inference_api(prompt: str):
if not HF_API_TOKEN:
raise HTTPException(status_code=500, detail="HF_API_TOKEN not configured on server.")
headers = {
"Authorization": f"Bearer {HF_API_TOKEN}",
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = {"inputs": prompt, "options": {"wait_for_model": True}}
# Adjust endpoint for text-generation models
url = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
with httpx.Client(timeout=60) as client:
r = client.post(url, headers=headers, json=payload)
if r.status_code >= 400:
raise HTTPException(status_code=502, detail=f"HuggingFace API error: {r.status_code} {r.text}")
return r.json()
def call_local_model(prompt: str):
global local_pipeline
if local_pipeline is None:
# lazy import to avoid heavy imports when not used
from transformers import pipeline
local_pipeline = pipeline("text-generation", model=MODEL_ID, max_length=512)
outputs = local_pipeline(prompt, do_sample=True, top_p=0.95, num_return_sequences=1)
return outputs[0]["generated_text"]
def build_prompt_from_messages(messages):
# Simple formatting — adapt for your model's preferred system/user/assistant format
prompt = ""
for m in messages:
role = m.get("role", "user")
content = m.get("content", "")
if role == "system":
prompt += f"[SYSTEM]: {content}\n"
elif role == "user":
prompt += f"User: {content}\n"
else:
prompt += f"Assistant: {content}\n"
prompt += "\nAssistant:" # instruct model to produce assistant reply
return prompt
# --- Simple connector: fetch a URL using given auth (ONLY use if you have permission) ---
import urllib.parse
from httpx import BasicAuth
async def fetch_url_with_auth(url: str, fetch_auth: Optional[dict] = None):
# fetch_auth examples:
# {"type": "bearer", "token": "XYZ"}
# {"type": "basic", "username":"u", "password":"p"}
# {"type": "session_cookie", "cookie_header": "sessionid=..."}
parsed = urllib.parse.urlparse(url)
if parsed.scheme not in ("http", "https"):
raise HTTPException(status_code=400, detail="Invalid URL scheme.")
headers = {"User-Agent": "CognitoConnector/1.0"}
auth = None
cookies = None
if fetch_auth:
t = fetch_auth.get("type")
if t == "bearer":
headers["Authorization"] = f"Bearer {fetch_auth.get('token')}"
elif t == "basic":
auth = BasicAuth(fetch_auth.get("username"), fetch_auth.get("password"))
elif t == "session_cookie":
cookies = {}
cookie_header = fetch_auth.get("cookie_header")
# cookie_header should be like "sessionid=abc; other=..."
# WARNING: you must own the credentials
if cookie_header:
# Convert header string into dict for httpx
cookie_pairs = [c.strip() for c in cookie_header.split(";") if "=" in c]
for pair in cookie_pairs:
k, v = pair.split("=", 1)
cookies[k.strip()] = v.strip()
else:
raise HTTPException(status_code=400, detail="Unsupported fetch_auth type.")
async with httpx.AsyncClient(timeout=30) as client:
r = await client.get(url, headers=headers, auth=auth, cookies=cookies)
if r.status_code >= 400:
raise HTTPException(status_code=502, detail=f"Upstream fetch failed: {r.status_code}")
return r.text
# --- End connector ---
@app.post("/chat")
async def chat(req: ChatRequest):
# If the user asked to fetch a URL first, fetch it (only with authorized creds)
fetched_text = ""
if req.fetch_url:
# SAFETY: ensure server only fetches allowed urls? You can add allowlist logic here.
fetched_text = await fetch_url_with_auth(req.fetch_url, req.fetch_auth)
# Build prompt including fetched text (if any)
prompt = build_prompt_from_messages(req.messages)
if fetched_text:
prompt = f"{prompt}\n\n---\nContext from {req.fetch_url}:\n{fetched_text}\n\n---\nRespond, using the context above where helpful.\nAssistant:"
if RUN_MODE == "inference_api":
hf_out = call_hf_inference_api(prompt)
# hf_out could be dict/list depending on model — try to extract text
if isinstance(hf_out, list) and isinstance(hf_out[0], dict):
# text-generation style
text = hf_out[0].get("generated_text", str(hf_out))
elif isinstance(hf_out, dict):
# some models return {"generated_text": "..."} or similar
text = hf_out.get("generated_text") or str(hf_out)
else:
text = str(hf_out)
else:
text = call_local_model(prompt)
return {"reply": text}
# Health
@app.get("/health")
def health():
return {"status": "ok", "mode": RUN_MODE}