Krish-05 committed on
Commit
20e16a7
·
verified ·
1 Parent(s): 1cbe61b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +6 -2
main.py CHANGED
@@ -10,10 +10,12 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- app = FastAPI() # <--- THIS LINE IS CRUCIAL
14
  MODEL_NAME = 'krishna_choudhary/AI_Assistant_Chatbot' # Your specific model name
15
 
16
  def get_llm():
 
 
17
  callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
18
  return Ollama(model=MODEL_NAME, callback_manager=callback_manager)
19
 
@@ -28,6 +30,8 @@ def read_root():
28
  async def ask_question(question: Question):
29
  try:
30
  llm = get_llm()
 
 
31
  response = llm.invoke(question.text)
32
  return {"response": response}
33
  except Exception as e:
@@ -41,4 +45,4 @@ async def startup_event():
41
 
42
  @app.on_event("shutdown")
43
  async def shutdown_event():
44
- logger.info("Shutting down")
 
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
+ app = FastAPI()
14
  MODEL_NAME = 'krishna_choudhary/AI_Assistant_Chatbot' # Your specific model name
15
 
16
  def get_llm():
17
+ # Note: StreamingStdOutCallbackHandler is for console output, not for API streaming.
18
+ # For actual API streaming, you'd integrate a custom callback that yields chunks.
19
  callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
20
  return Ollama(model=MODEL_NAME, callback_manager=callback_manager)
21
 
 
30
  async def ask_question(question: Question):
31
  try:
32
  llm = get_llm()
33
+ # For simplicity, we're using invoke() which returns the full response at once.
34
+ # If you need true streaming to the client, you'd use llm.stream() and a custom StreamingResponse.
35
  response = llm.invoke(question.text)
36
  return {"response": response}
37
  except Exception as e:
 
45
 
46
  @app.on_event("shutdown")
47
  async def shutdown_event():
48
+ logger.info("Shutting down")