alaselababatunde commited on
Commit
13d2b5e
·
1 Parent(s): 8351aff
Files changed (1) hide show
  1. main.py +64 -31
main.py CHANGED
@@ -4,78 +4,111 @@ from transformers import pipeline
4
  from langchain.llms import HuggingFacePipeline
5
  from langchain.chains import LLMChain
6
  from langchain.prompts import PromptTemplate
 
7
  import torch
8
  import logging
9
 
10
- # ===== CONFIG =====
 
 
11
  API_SECRET = "techdisciplesai404"
12
  MODEL_NAME = "google/flan-t5-large"
13
  DEVICE = 0 if torch.cuda.is_available() else -1
14
 
15
- # ===== LOGGING =====
 
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger("TechDisciplesAI")
18
 
19
- # ===== INITIALIZE APP =====
20
- app = FastAPI(title="TechDisciples AI", version="2.0")
 
 
21
 
22
- # ===== MODEL SETUP =====
 
 
23
  try:
24
  logger.info(f"🚀 Loading model: {MODEL_NAME}")
25
- pipe = pipeline(
 
26
  "text2text-generation",
27
  model=MODEL_NAME,
28
  device=DEVICE,
29
  max_new_tokens=256,
30
  temperature=0.3,
31
- do_sample=True
 
32
  )
33
 
34
- llm = HuggingFacePipeline(pipeline=pipe)
35
  logger.info("✅ Model loaded successfully.")
 
36
  except Exception as e:
37
  logger.error(f"❌ Failed to load model: {e}")
38
  llm = None
39
 
40
- # ===== PROMPT TEMPLATE =====
 
 
 
 
 
 
 
41
  prompt_template = """
42
- You are a Christian conversational AI named TechDisciples AI.
43
- Answer the question naturally and clearly, providing biblical or inspirational insight where possible.
44
 
45
- Question: {query}
 
46
 
47
- Response:
 
48
  """
49
 
50
- prompt = PromptTemplate(template=prompt_template, input_variables=["query"])
51
- chain = LLMChain(prompt=prompt, llm=llm)
52
-
53
- # ===== REQUEST MODEL =====
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  class QueryInput(BaseModel):
55
  query: str
 
56
 
 
 
 
 
 
 
57
 
58
- # ===== ROUTES =====
59
  @app.post("/ai-chat")
60
  async def ai_chat(data: QueryInput, x_api_key: str = Header(None)):
 
61
  if x_api_key != API_SECRET:
62
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key")
63
 
64
- if llm is None:
65
  raise HTTPException(status_code=500, detail="Model not initialized")
66
 
67
- user_query = data.query.strip()
68
- if not user_query:
69
- raise HTTPException(status_code=400, detail="Query cannot be empty")
70
-
71
  try:
72
- response = chain.run(query=user_query)
73
- return {"reply": response.strip(), "tone_used": "neutral"}
74
  except Exception as e:
75
- logger.error(f"⚠️ Generation error: {e}")
76
  raise HTTPException(status_code=500, detail="Model failed to respond")
77
-
78
-
79
- @app.get("/")
80
- async def root():
81
- return {"message": "✅ TechDisciples AI (LangChain) is running."}
 
4
import logging
import os
import secrets

import torch
from langchain.chains import LLMChain
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
10
 
11
# ===============================================
# CONFIGURATION
# ===============================================
# Security fix: the API key was hard-coded in source. Read it from the
# environment, keeping the old literal as a fallback so existing deployments
# keep working. Rotate the key and drop the fallback as soon as possible.
API_SECRET = os.getenv("API_SECRET", "techdisciplesai404")
MODEL_NAME = "google/flan-t5-large"  # seq2seq model served via HF pipeline
DEVICE = 0 if torch.cuda.is_available() else -1  # 0 = first GPU, -1 = CPU
17
 
18
# ===============================================
# LOGGING SETUP
# ===============================================
# Root logger at INFO so model-load and request diagnostics are visible
# in the service logs by default.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("TechDisciplesAI")
23
 
24
# ===============================================
# FASTAPI APP
# ===============================================
# Single application instance; routes below attach to it via decorators.
app = FastAPI(
    title="Tech Disciples AI (LangChain Conversational)",
    version="3.0",
)
28
 
29
# ===============================================
# LOAD MODEL USING PIPELINE + LANGCHAIN
# ===============================================
# Build the HF text2text pipeline once at startup and wrap it for LangChain.
# On any failure `llm` is left as None; the /ai-chat route checks for that
# and returns a 500 instead of crashing.
try:
    logger.info(f"🚀 Loading model: {MODEL_NAME}")

    generation_config = {
        "model": MODEL_NAME,
        "device": DEVICE,        # GPU 0 when available, otherwise CPU (-1)
        "max_new_tokens": 256,
        "temperature": 0.3,      # low temperature: mostly-deterministic replies
        "do_sample": True,
        "top_p": 0.9,            # nucleus-sampling cap
    }
    hf_pipeline = pipeline("text2text-generation", **generation_config)

    llm = HuggingFacePipeline(pipeline=hf_pipeline)
    logger.info("✅ Model loaded successfully.")

except Exception as e:  # broad on purpose: any load failure must not kill the app
    logger.error(f"❌ Failed to load model: {e}")
    llm = None
51
 
52
# ===============================================
# MEMORY SYSTEM
# ===============================================
# NOTE(review): this single module-level buffer is shared by *all* callers,
# so different users' conversations bleed into each other even though
# QueryInput accepts a session_id. Presumably the intent is one memory per
# session — confirm with the author before shipping multi-user.
memory = ConversationBufferMemory(memory_key="conversation_history")
56
+
57
# ===============================================
# PROMPT TEMPLATE
# ===============================================
# The template injects the running transcript ahead of the new user query.
# {conversation_history} is filled in by the chain's ConversationBufferMemory
# (its memory_key matches this placeholder); {query} comes from the caller.
prompt_template = """
You are Tech Disciples AI — a spiritually aware, intelligent, and kind conversational assistant.
You offer thoughtful, biblical, and insightful answers with grace, empathy, and calm intelligence.

Conversation so far:
{conversation_history}

User: {query}
Tech Disciples AI:
"""

prompt = PromptTemplate(
    template=prompt_template,
    input_variables=["conversation_history", "query"]
)
75
+
76
# ===============================================
# LLM CHAIN (with memory)
# ===============================================
# Bug fix: LLMChain validates its `llm` argument at construction, so building
# it unconditionally with llm=None (model failed to load) raised at import
# time and took the whole app down. Build the chain only when the model is
# available; the /ai-chat route already returns 500 when `llm` is falsy,
# so `chain` is never dereferenced while None.
if llm is not None:
    chain = LLMChain(
        prompt=prompt,
        llm=llm,
        memory=memory
    )
else:
    chain = None
84
+
85
# ===============================================
# REQUEST MODEL
# ===============================================
class QueryInput(BaseModel):
    # Request body for POST /ai-chat.
    query: str  # the user's message; the route strips whitespace before use
    session_id: str | None = "default"  # optional: could be user/session-based
    # NOTE(review): session_id is accepted but never read — conversation
    # memory is a single global buffer. Verify intended scoping.
91
 
92
# ===============================================
# ROUTES
# ===============================================
@app.get("/")
async def root():
    """Liveness probe: confirm the service is up and responding."""
    status_payload = {"message": "✅ Tech Disciples AI (LangChain Memory) is running."}
    return status_payload
98
 
 
99
@app.post("/ai-chat")
async def ai_chat(data: QueryInput, x_api_key: str = Header(None)):
    """Answer a user query with the LangChain conversational chain.

    Requires the shared secret in the ``x-api-key`` header. Returns
    ``{"reply": <model text>}`` on success. Raises 403 on a bad/missing key,
    400 on an empty query, and 500 when the model is unavailable or errors.
    """
    # --- Authentication ---
    # secrets.compare_digest is constant-time, avoiding timing leaks of the
    # key's length/prefix that a plain != comparison allows.
    if not x_api_key or not secrets.compare_digest(x_api_key, API_SECRET):
        raise HTTPException(status_code=403, detail="Forbidden: Invalid API key")

    if not llm:
        raise HTTPException(status_code=500, detail="Model not initialized")

    # --- Validate query ---
    # Restores the empty-query check the previous revision had; without it a
    # blank prompt is sent straight to the model.
    user_query = data.query.strip()
    if not user_query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    # --- Process Query ---
    try:
        response = chain.run(query=user_query)
        return {"reply": response.strip()}
    except Exception as e:
        logger.error(f"⚠️ Model error: {e}")
        raise HTTPException(status_code=500, detail="Model failed to respond")