File size: 2,543 Bytes
c284cbf
 
6f5c63d
379b034
c284cbf
379b034
c284cbf
 
379b034
 
c284cbf
 
 
 
 
6f5c63d
379b034
6f5c63d
 
379b034
6f5c63d
 
 
 
 
 
 
 
379b034
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5c63d
c284cbf
379b034
 
c284cbf
6f5c63d
 
 
 
 
 
 
 
 
 
379b034
 
 
 
 
 
 
6f5c63d
379b034
 
 
 
 
 
 
 
 
 
 
 
6f5c63d
 
 
c284cbf
 
379b034
c284cbf
 
 
379b034
 
6f5c63d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import logging
import queue
import threading
from typing import Generator

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from langchain_community.llms import Ollama
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import CallbackManager

# Configure root logging once at import time; the module-level logger
# follows the stdlib convention of logging.getLogger(__name__).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Single FastAPI application instance used by all route decorators below.
app = FastAPI()

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # You can restrict this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class QueryWithContext(BaseModel):
    """Request body for /ask: a question plus the lecture-note context to ground the answer in."""

    question: str  # the user's question
    context: str  # lecture notes injected into the prompt
    model: str = "tinyllama"  # default model

# Streaming callback that yields tokens
class StreamingGeneratorCallback(BaseCallbackHandler):
    """Bridge LLM tokens from the generation thread to an HTTP stream.

    The producer thread (running the LLM) pushes each new token via
    ``on_llm_new_token``; the consumer (the StreamingResponse body)
    pulls them via ``stream``. A thread-safe ``queue.Queue`` replaces
    the original plain list: ``list.pop(0)`` was O(n), and the old
    ``while`` loop busy-waited at 100% CPU whenever the list was empty.
    The blocking ``get`` with a short timeout removes that spin while
    still letting the consumer re-check the ``streaming`` flag.
    """

    def __init__(self):
        self.buffer = ""  # kept for backward compatibility; not used internally
        self.queue: queue.Queue = queue.Queue()  # thread-safe token channel
        self.streaming = True  # producer flips this to False when generation ends

    def on_llm_new_token(self, token: str, **kwargs):
        # Invoked by LangChain on the generation thread for each token.
        self.queue.put(token)

    def stream(self) -> Generator[str, None, None]:
        # Yield until the producer has finished AND the queue is drained.
        while self.streaming or not self.queue.empty():
            try:
                # Block briefly instead of spinning; the timeout lets us
                # notice `streaming = False` without waiting forever.
                yield self.queue.get(timeout=0.1)
            except queue.Empty:
                continue

@app.get("/")
def root():
    """Simple liveness probe confirming the service is reachable."""
    status_payload = {"message": "FastAPI Ollama is running."}
    return status_payload

@app.post("/ask")
def ask(query: QueryWithContext):
    """Stream an LLM answer grounded in the supplied lecture notes.

    Generation runs on a background thread that pushes tokens into the
    callback's queue; the StreamingResponse drains that queue so the
    client receives tokens as they are produced.

    Raises:
        HTTPException: 500 if setting up the stream fails.
    """
    prompt = f"""You are an expert in quantitative methods. Based on the following lecture notes, answer the user's question.

Lecture notes:
{query.context}

User's question:
{query.question}
"""

    # Callback bridges the generation thread and the HTTP response.
    stream_callback = StreamingGeneratorCallback()
    callback_manager = CallbackManager([stream_callback])

    llm = Ollama(model=query.model, callback_manager=callback_manager)

    try:
        def run_llm():
            # Runs off the request thread so streaming can begin before
            # the full answer exists.
            try:
                llm.invoke(prompt)
            except Exception:
                # The original version silently dropped worker-thread
                # errors; log them so failures are diagnosable.
                logger.exception("LLM generation failed")
            finally:
                # Always unblock the consumer, even after a failure,
                # or stream() would loop forever.
                stream_callback.streaming = False

        # daemon=True so a hung generation cannot block interpreter exit.
        threading.Thread(target=run_llm, daemon=True).start()

        return StreamingResponse(stream_callback.stream(), media_type="text/plain")

    except Exception as e:
        # Surface setup errors (e.g. thread creation) as a 500.
        raise HTTPException(status_code=500, detail=str(e))

# NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
# lifespan handlers — consider migrating when upgrading FastAPI.
@app.on_event("startup")
async def startup_event():
    """Log that the application has started accepting requests."""
    logger.info("FastAPI is starting up...")

@app.on_event("shutdown")
async def shutdown_event():
    """Log that the application is shutting down."""
    logger.info("FastAPI is shutting down.")