"""Cognito chat service: a small FastAPI app that proxies chat requests to a
Hugging Face model, either through the hosted Inference API or a local
transformers pipeline, optionally grounding the prompt with fetched web content."""

import os
import urllib.parse
from typing import List, Optional

import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from httpx import BasicAuth
from pydantic import BaseModel

# Load configuration from a local .env file, if present.
load_dotenv()

HF_API_TOKEN = os.getenv("HF_API_TOKEN")
MODEL_ID = os.getenv("HF_MODEL_ID", "gpt2")
# "inference_api" calls the hosted Hugging Face Inference API;
# any other value falls back to a local transformers pipeline.
RUN_MODE = os.getenv("RUN_MODE", "inference_api")

# Lazily-initialized transformers pipeline; only used outside inference_api mode.
local_pipeline = None

app = FastAPI(title="Cognito - Chat API")


class ChatRequest(BaseModel):
    session_id: Optional[str] = None
    messages: List[dict]  # each message: {"role": "system"|"user"|"assistant", "content": "..."}
    fetch_url: Optional[str] = None  # optional URL whose contents are appended to the prompt
    fetch_auth: Optional[dict] = None  # optional auth for fetch_url; see fetch_url_with_auth


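# Illustrative /chat request body matching the fields above (placeholder values,
# not real credentials):
# {
#   "session_id": "abc123",
#   "messages": [
#     {"role": "system", "content": "You are concise."},
#     {"role": "user", "content": "Summarize the page below."}
#   ],
#   "fetch_url": "https://example.com/doc",
#   "fetch_auth": {"type": "bearer", "token": "<token>"}
# }

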
def call_hf_inference_api(prompt: str):
    """Send the prompt to the hosted Hugging Face Inference API and return the parsed JSON."""
    if not HF_API_TOKEN:
        raise HTTPException(status_code=500, detail="HF_API_TOKEN not configured on server.")
    headers = {
        "Authorization": f"Bearer {HF_API_TOKEN}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    # wait_for_model asks the API to block until the model is loaded instead of returning 503.
    payload = {"inputs": prompt, "options": {"wait_for_model": True}}

    url = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
    with httpx.Client(timeout=60) as client:
        r = client.post(url, headers=headers, json=payload)
        if r.status_code >= 400:
            raise HTTPException(status_code=502, detail=f"HuggingFace API error: {r.status_code} {r.text}")
        return r.json()


def call_local_model(prompt: str):
    """Generate a completion with a locally loaded transformers pipeline."""
    global local_pipeline
    if local_pipeline is None:
        # Import lazily so transformers is only required in local mode.
        from transformers import pipeline

        local_pipeline = pipeline("text-generation", model=MODEL_ID)
    # max_length caps the combined prompt + completion length in tokens.
    outputs = local_pipeline(prompt, max_length=512, do_sample=True, top_p=0.95, num_return_sequences=1)
    return outputs[0]["generated_text"]


def build_prompt_from_messages(messages):
    """Flatten a list of chat messages into a single plain-text prompt."""
    prompt = ""
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        if role == "system":
            prompt += f"[SYSTEM]: {content}\n"
        elif role == "user":
            prompt += f"User: {content}\n"
        else:
            prompt += f"Assistant: {content}\n"
    # Trailing cue so the model continues as the assistant.
    prompt += "\nAssistant:"
    return prompt


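# Illustrative only: for
#   [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]
# build_prompt_from_messages returns:
#   [SYSTEM]: Be brief.
#   User: Hi
#
#   Assistant:

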
async def fetch_url_with_auth(url: str, fetch_auth: Optional[dict] = None):
    """Fetch a URL, optionally authenticating via bearer token, basic auth,
    or a raw session-cookie header.

    Note: only the URL scheme is validated here; private or internal addresses
    are not blocked, so user-supplied URLs carry SSRF risk.
    """
    parsed = urllib.parse.urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise HTTPException(status_code=400, detail="Invalid URL scheme.")
    headers = {"User-Agent": "CognitoConnector/1.0"}
    auth = None
    cookies = None
    if fetch_auth:
        t = fetch_auth.get("type")
        if t == "bearer":
            headers["Authorization"] = f"Bearer {fetch_auth.get('token')}"
        elif t == "basic":
            auth = BasicAuth(fetch_auth.get("username"), fetch_auth.get("password"))
        elif t == "session_cookie":
            cookies = {}
            cookie_header = fetch_auth.get("cookie_header")
            if cookie_header:
                # Parse a raw "k1=v1; k2=v2" Cookie header into a dict.
                cookie_pairs = [c.strip() for c in cookie_header.split(";") if "=" in c]
                for pair in cookie_pairs:
                    k, v = pair.split("=", 1)
                    cookies[k.strip()] = v.strip()
        else:
            raise HTTPException(status_code=400, detail="Unsupported fetch_auth type.")
    async with httpx.AsyncClient(timeout=30) as client:
        r = await client.get(url, headers=headers, auth=auth, cookies=cookies)
        if r.status_code >= 400:
            raise HTTPException(status_code=502, detail=f"Upstream fetch failed: {r.status_code}")
        return r.text


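# Illustrative fetch_auth shapes accepted above (placeholder values):
#   {"type": "bearer", "token": "<token>"}
#   {"type": "basic", "username": "<user>", "password": "<pass>"}
#   {"type": "session_cookie", "cookie_header": "sid=<value>; theme=dark"}

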
@app.post("/chat")
async def chat(req: ChatRequest):
    # Optionally fetch external content to ground the reply.
    fetched_text = ""
    if req.fetch_url:
        fetched_text = await fetch_url_with_auth(req.fetch_url, req.fetch_auth)

    prompt = build_prompt_from_messages(req.messages)
    if fetched_text:
        prompt = (
            f"{prompt}\n\n---\nContext from {req.fetch_url}:\n{fetched_text}\n\n---\n"
            "Respond, using the context above where helpful.\nAssistant:"
        )

    if RUN_MODE == "inference_api":
        hf_out = call_hf_inference_api(prompt)
        # Text-generation models usually return [{"generated_text": ...}];
        # fall back to a string dump for any other shape.
        if isinstance(hf_out, list) and hf_out and isinstance(hf_out[0], dict):
            text = hf_out[0].get("generated_text", str(hf_out))
        elif isinstance(hf_out, dict):
            text = hf_out.get("generated_text") or str(hf_out)
        else:
            text = str(hf_out)
    else:
        text = call_local_model(prompt)

    return {"reply": text}


@app.get("/health")
def health():
    return {"status": "ok", "mode": RUN_MODE}


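# Optional local entry point; a minimal sketch assuming uvicorn is installed
# (it is not imported or declared elsewhere in this module).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request (adjust host/port to your deployment):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}]}'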