File size: 4,387 Bytes
b0ee4b2
 
 
 
 
 
 
 
 
834285f
b0ee4b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
834285f
b0ee4b2
 
 
834285f
b0ee4b2
be0a4ce
b0ee4b2
 
 
834285f
be0a4ce
834285f
 
 
 
 
 
 
 
 
 
be0a4ce
834285f
 
 
 
 
 
 
 
 
 
 
 
be0a4ce
834285f
 
 
be0a4ce
834285f
 
be0a4ce
834285f
 
 
 
 
 
 
 
 
 
 
 
 
be0a4ce
 
 
834285f
 
 
be0a4ce
 
 
834285f
 
 
 
 
 
b0ee4b2
be0a4ce
b0ee4b2
 
 
834285f
b0ee4b2
 
be0a4ce
b0ee4b2
 
834285f
b0ee4b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import logging
import json
import httpx
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = FastAPI()

# Optional: Configure CORS if needed
origins = [
    # Add allowed origins if you implement a frontend later
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # Adjust as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load your API key from the environment (defaults to "change_me").
API_KEY = os.environ.get("API_KEY", "change_me")
# SECURITY: never write the secret itself to the logs — record only
# whether a non-default key has been configured.
logger.debug("API key configured: %s", API_KEY != "change_me")

# URL of the running Ollama server; overridable via the environment,
# defaulting to the local instance (backward-compatible).
OLLAMA_SERVER_URL = os.environ.get(
    "OLLAMA_SERVER_URL", "http://localhost:11434/api/generate"
)
logger.debug(f"Ollama server URL: {OLLAMA_SERVER_URL}")


@app.post("/api/generate")
async def generate(request: Request):
    """Proxy a text-generation request to the Ollama server.

    Expects a JSON body with an optional ``model`` and a required
    ``prompt``, plus an ``Authorization: Bearer <API_KEY>`` header.
    Streams Ollama's response back to the caller unchanged.

    Raises:
        HTTPException: 401 for missing/invalid credentials, 400 for a
            missing prompt, 500 for unexpected server-side failures.
    """
    try:
        # Validate the API key FIRST, so unauthenticated callers cause no
        # body parsing and learn nothing from error ordering.
        auth_header = request.headers.get("Authorization")

        if not auth_header or not auth_header.startswith("Bearer "):
            logger.error("Missing or invalid Authorization header")
            raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")

        # maxsplit=1 keeps any embedded spaces in the token intact.
        token = auth_header.split(" ", 1)[1]
        if token != API_KEY:
            # SECURITY: do not log the supplied credential itself.
            logger.error("Invalid API key provided")
            raise HTTPException(status_code=401, detail="Invalid API key")

        # Parse the incoming request
        body = await request.json()
        model = body.get("model", "hf.co/abanm/Dubs-Q8_0-GGUF:latest")  # Default model
        prompt_text = body.get("prompt", "")

        if not prompt_text:
            logger.error("No prompt provided in the request")
            raise HTTPException(status_code=400, detail="No prompt provided")

        logger.debug("Request body: %s", body)

        # Prepare request payload
        payload = {"model": model, "prompt": prompt_text}
        logger.debug("Payload prepared for Ollama: %s", payload)

        # Stream response from Ollama
        async def stream_response():
            try:
                async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
                    async with client.stream(
                        "POST", OLLAMA_SERVER_URL, json=payload, headers={"Content-Type": "application/json"}
                    ) as response:
                        logger.info(f"Response status code from Ollama: {response.status_code}")

                        if response.status_code != 200:
                            # On a streamed httpx response the body must be read
                            # explicitly with aread(); `response.text` is a plain
                            # property and is unavailable until the body is read
                            # (the original `await response.text()` would raise).
                            error_body = (await response.aread()).decode("utf-8", errors="replace")
                            logger.error("HTTP error: %s - %s", response.status_code, error_body)
                            yield json.dumps({"error": f"HTTP error: {response.status_code}"})
                            return

                        async for chunk in response.aiter_text():
                            logger.debug("Chunk received: %s", chunk)
                            yield chunk
            except httpx.ReadTimeout:
                logger.error("ReadTimeout while waiting for response chunks")
                yield json.dumps({"error": "Server response timeout. Try again later."})
            except httpx.RequestError as exc:
                logger.error(f"Request error while communicating with Ollama: {str(exc)}")
                yield json.dumps({"error": "Network error occurred while communicating with Ollama"})
            except Exception as exc:
                logger.exception(f"Unexpected error during streaming: {str(exc)}")
                yield json.dumps({"error": "An unexpected error occurred during streaming."})

        return StreamingResponse(stream_response(), media_type="application/json")

    except HTTPException:
        # Let deliberate HTTP errors (400/401) propagate unchanged — the
        # generic handler below would otherwise remap them to a 500.
        raise
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred")


@app.get("/health")
async def health():
    """Liveness probe: always reports the service as up."""
    logger.info("Health check endpoint accessed")
    status_report = {"status": "OK"}
    return status_report


# Launch a development server only when executed directly,
# not when imported by an external ASGI runner.
if __name__ == "__main__":
    import uvicorn

    logger.info("Starting FastAPI application")
    uvicorn.run(app, host="0.0.0.0", port=7860)