# app.py
"""Cognito - Chat API.

FastAPI service that forwards chat prompts either to the Hugging Face
Inference API or to a locally loaded transformers pipeline, with an optional
authorized URL-fetch connector for adding page context to the prompt.
"""

import asyncio
import os
import urllib.parse
from typing import List, Optional

import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from httpx import BasicAuth
from pydantic import BaseModel

load_dotenv()

HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # Put your Hugging Face token here if using HF Inference API
MODEL_ID = os.getenv("HF_MODEL_ID", "gpt2")  # default, replace with your Cognito model id
RUN_MODE = os.getenv("RUN_MODE", "inference_api")  # or "local"

# If RUN_MODE == "local", we will lazily import transformers and create a pipeline
local_pipeline = None

app = FastAPI(title="Cognito - Chat API")


class ChatRequest(BaseModel):
    """Request payload for POST /chat."""

    session_id: Optional[str] = None
    # [{"role":"user","content":"..."} , ...]
    messages: List[dict]
    # optional: server will fetch a URL (only if connector authorized)
    fetch_url: Optional[str] = None
    # optional auth info for connector (see docs)
    fetch_auth: Optional[dict] = None


def call_hf_inference_api(prompt: str):
    """Send *prompt* to the Hugging Face Inference API and return the parsed JSON.

    Raises:
        HTTPException(500): no HF_API_TOKEN configured.
        HTTPException(502): upstream returned an error status.

    NOTE: this is a blocking (sync) call — from async code, run it in a
    thread pool (see the /chat handler) so the event loop is not stalled.
    """
    if not HF_API_TOKEN:
        raise HTTPException(status_code=500, detail="HF_API_TOKEN not configured on server.")
    headers = {
        "Authorization": f"Bearer {HF_API_TOKEN}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    # wait_for_model avoids an immediate 503 while a hosted model cold-starts.
    payload = {"inputs": prompt, "options": {"wait_for_model": True}}
    # Adjust endpoint for text-generation models
    url = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
    with httpx.Client(timeout=60) as client:
        r = client.post(url, headers=headers, json=payload)
        if r.status_code >= 400:
            raise HTTPException(status_code=502, detail=f"HuggingFace API error: {r.status_code} {r.text}")
        return r.json()


def call_local_model(prompt: str):
    """Generate a completion with a locally loaded transformers pipeline.

    The pipeline is created lazily on first call so the heavy transformers
    import is skipped entirely when RUN_MODE != "local".
    """
    global local_pipeline
    if local_pipeline is None:
        # lazy import to avoid heavy imports when not used
        from transformers import pipeline
        local_pipeline = pipeline("text-generation", model=MODEL_ID, max_length=512)
    outputs = local_pipeline(prompt, do_sample=True, top_p=0.95, num_return_sequences=1)
    return outputs[0]["generated_text"]


def build_prompt_from_messages(messages):
    """Flatten a chat message list into a single plain-text prompt.

    Simple formatting — adapt for your model's preferred
    system/user/assistant format.
    """
    parts = []
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        if role == "system":
            parts.append(f"[SYSTEM]: {content}\n")
        elif role == "user":
            parts.append(f"User: {content}\n")
        else:
            parts.append(f"Assistant: {content}\n")
    # instruct model to produce assistant reply
    return "".join(parts) + "\nAssistant:"


# --- Simple connector: fetch a URL using given auth (ONLY use if you have permission) ---
async def fetch_url_with_auth(url: str, fetch_auth: Optional[dict] = None):
    """Fetch *url* with optional caller-supplied credentials; return body text.

    fetch_auth examples:
        {"type": "bearer", "token": "XYZ"}
        {"type": "basic", "username": "u", "password": "p"}
        {"type": "session_cookie", "cookie_header": "sessionid=..."}

    WARNING: you must own the credentials.
    SECURITY NOTE(review): this fetches any http(s) URL the caller names,
    including internal/private addresses (SSRF) — add an allowlist or
    deny-private-IP check before exposing this publicly.
    """
    parsed = urllib.parse.urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise HTTPException(status_code=400, detail="Invalid URL scheme.")
    headers = {"User-Agent": "CognitoConnector/1.0"}
    auth = None
    cookies = None
    if fetch_auth:
        t = fetch_auth.get("type")
        if t == "bearer":
            headers["Authorization"] = f"Bearer {fetch_auth.get('token')}"
        elif t == "basic":
            auth = BasicAuth(fetch_auth.get("username"), fetch_auth.get("password"))
        elif t == "session_cookie":
            cookies = {}
            # cookie_header should be like "sessionid=abc; other=..."
            # WARNING: you must own the credentials
            cookie_header = fetch_auth.get("cookie_header")
            if cookie_header:
                # Convert header string into dict for httpx
                cookie_pairs = [c.strip() for c in cookie_header.split(";") if "=" in c]
                for pair in cookie_pairs:
                    k, v = pair.split("=", 1)
                    cookies[k.strip()] = v.strip()
        else:
            raise HTTPException(status_code=400, detail="Unsupported fetch_auth type.")
    async with httpx.AsyncClient(timeout=30) as client:
        r = await client.get(url, headers=headers, auth=auth, cookies=cookies)
        if r.status_code >= 400:
            raise HTTPException(status_code=502, detail=f"Upstream fetch failed: {r.status_code}")
        return r.text
# --- End connector ---


@app.post("/chat")
async def chat(req: ChatRequest):
    """Run one chat turn: optionally fetch URL context, build a prompt, generate a reply."""
    # If the user asked to fetch a URL first, fetch it (only with authorized creds)
    fetched_text = ""
    if req.fetch_url:
        # SAFETY: ensure server only fetches allowed urls? You can add allowlist logic here.
        fetched_text = await fetch_url_with_auth(req.fetch_url, req.fetch_auth)

    # Build prompt including fetched text (if any)
    prompt = build_prompt_from_messages(req.messages)
    if fetched_text:
        # NOTE(review): build_prompt_from_messages already ends with "\nAssistant:",
        # so this yields the marker twice — confirm the model tolerates that.
        prompt = f"{prompt}\n\n---\nContext from {req.fetch_url}:\n{fetched_text}\n\n---\nRespond, using the context above where helpful.\nAssistant:"

    loop = asyncio.get_running_loop()
    if RUN_MODE == "inference_api":
        # The HF helper is synchronous; off-load it to the default executor so
        # the event loop is not blocked for the duration of the upstream call.
        hf_out = await loop.run_in_executor(None, call_hf_inference_api, prompt)
        # hf_out could be dict/list depending on model — try to extract text
        # (guard against an empty list before indexing hf_out[0])
        if isinstance(hf_out, list) and hf_out and isinstance(hf_out[0], dict):
            # text-generation style
            text = hf_out[0].get("generated_text", str(hf_out))
        elif isinstance(hf_out, dict):
            # some models return {"generated_text": "..."} or similar
            text = hf_out.get("generated_text") or str(hf_out)
        else:
            text = str(hf_out)
    else:
        # Local generation is compute-bound and synchronous; same off-load.
        text = await loop.run_in_executor(None, call_local_model, prompt)
    return {"reply": text}


# Health
@app.get("/health")
def health():
    """Liveness probe: report service status and run mode."""
    return {"status": "ok", "mode": RUN_MODE}