marcweibel committed on
Commit
379b034
·
verified ·
1 Parent(s): 6f5c63d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -19
app.py CHANGED
@@ -1,41 +1,52 @@
1
  import logging
2
  from fastapi import FastAPI, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
 
4
  from pydantic import BaseModel
 
5
  from langchain_community.llms import Ollama
6
- from langchain.callbacks.manager import CallbackManager
7
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 
8
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
  app = FastAPI()
13
 
14
- # Allow cross-origin requests (for browser-based JS)
15
  app.add_middleware(
16
  CORSMiddleware,
17
- allow_origins=["*"], # Adjust for production if needed
18
  allow_credentials=True,
19
  allow_methods=["*"],
20
  allow_headers=["*"],
21
  )
22
 
23
- MODEL_NAME = 'tinyllama'
24
-
25
- def get_llm():
26
- callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
27
- return Ollama(model=MODEL_NAME, callback_manager=callback_manager)
28
-
29
- class Question(BaseModel):
30
- text: str
31
-
32
  class QueryWithContext(BaseModel):
33
  question: str
34
  context: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  @app.get("/")
37
- def read_root():
38
- return {"Hello": f"Welcome to {MODEL_NAME} FastAPI"}
39
 
40
  @app.post("/ask")
41
  def ask(query: QueryWithContext):
@@ -47,18 +58,35 @@ Lecture notes:
47
  User's question:
48
  {query.question}
49
  """
50
- llm = get_llm()
 
 
 
 
 
 
51
  try:
52
- response = llm.invoke(prompt)
53
- return {"answer": response}
 
 
 
 
 
 
 
 
 
 
54
  except Exception as e:
55
  raise HTTPException(status_code=500, detail=str(e))
56
 
57
  @app.on_event("startup")
58
  async def startup_event():
59
- logger.info(f"Starting up with model: {MODEL_NAME}")
60
 
61
  @app.on_event("shutdown")
62
  async def shutdown_event():
63
- logger.info("Shutting down")
 
64
 
 
# Standard library
import logging
import time
from typing import Generator

# Third-party
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.llms import Ollama
# Module-wide logging: INFO level, logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Enable CORS for browser-based clients.
# NOTE(review): allow_origins=["*"] is wide open — you can restrict this
# in production.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
25
 
 
 
 
 
 
 
 
 
 
class QueryWithContext(BaseModel):
    """Request payload for the /ask endpoint.

    Carries the user's question together with the supporting context
    (lecture notes) it should be answered from.  ``model`` names the
    Ollama model to use and defaults to "tinyllama".
    """

    question: str
    context: str
    model: str = "tinyllama"  # default model
30
+
# Streaming callback that yields tokens
class StreamingGeneratorCallback(BaseCallbackHandler):
    """LangChain callback that buffers LLM tokens for an HTTP stream.

    Producer/consumer pair: the LLM worker thread pushes tokens in via
    ``on_llm_new_token``; the response generator pulls them out via
    ``stream()``.  The producer clears ``streaming`` when generation is
    done, which lets ``stream()`` terminate once the queue is drained.
    """

    def __init__(self):
        self.buffer = ""       # kept for interface compatibility; not used here
        self.queue = []        # FIFO of tokens awaiting delivery to the client
        self.streaming = True  # producer flips to False when generation ends

    def on_llm_new_token(self, token: str, **kwargs):
        """Invoked by LangChain for each generated token; enqueue it."""
        self.queue.append(token)

    def stream(self) -> Generator[str, None, None]:
        """Yield queued tokens until generation ends and the queue is empty.

        Fix: the original version looped with no pause when the queue was
        empty, busy-spinning a full CPU core for the whole generation.
        Sleep briefly while idle so the consumer thread yields the CPU.
        """
        while self.streaming or self.queue:
            if self.queue:
                yield self.queue.pop(0)
            else:
                # Queue empty but generation still running: back off briefly
                # instead of hot-spinning.
                time.sleep(0.01)
46
 
47
@app.get("/")
def root():
    """Liveness probe: confirms the API process is up and serving."""
    greeting = {"message": "FastAPI Ollama is running."}
    return greeting
50
 
51
  @app.post("/ask")
52
  def ask(query: QueryWithContext):
 
58
  User's question:
59
  {query.question}
60
  """
61
+
62
+ # Create streaming callback
63
+ stream_callback = StreamingGeneratorCallback()
64
+ callback_manager = CallbackManager([stream_callback])
65
+
66
+ llm = Ollama(model=query.model, callback_manager=callback_manager)
67
+
68
  try:
69
+ # Start generation in background
70
+ def run_llm():
71
+ try:
72
+ llm.invoke(prompt)
73
+ finally:
74
+ stream_callback.streaming = False
75
+
76
+ import threading
77
+ threading.Thread(target=run_llm).start()
78
+
79
+ return StreamingResponse(stream_callback.stream(), media_type="text/plain")
80
+
81
  except Exception as e:
82
  raise HTTPException(status_code=500, detail=str(e))
83
 
84
@app.on_event("startup")
async def startup_event():
    """Log once when the ASGI server brings the application up."""
    logger.info("FastAPI is starting up...")
87
 
88
@app.on_event("shutdown")
async def shutdown_event():
    """Log once when the ASGI server tears the application down."""
    logger.info("FastAPI is shutting down.")
91
+
92