alaselababatunde committed on
Commit
75aa43a
·
1 Parent(s): 4ea2add
Files changed (1) hide show
  1. main.py +65 -119
main.py CHANGED
@@ -1,135 +1,81 @@
1
- # ==============================================================
2
- # Tech Disciples AI Backend — Free-Response Edition (Flan-T5 Fixed)
3
- # ==============================================================
4
-
5
- import os
6
- import logging
7
- import torch
8
- from fastapi import FastAPI, Header, HTTPException
9
- from fastapi.responses import JSONResponse
10
  from pydantic import BaseModel
11
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
12
-
13
- # --------------------------------------------------------------
14
- # Logging Setup
15
- # --------------------------------------------------------------
16
- logging.basicConfig(
17
- level=logging.INFO,
18
- format="%(asctime)s [%(levelname)s] %(message)s"
19
- )
20
- logger = logging.getLogger("Tech Disciples AI")
21
-
22
- # --------------------------------------------------------------
23
- # FastAPI App
24
- # --------------------------------------------------------------
25
- app = FastAPI(title="Tech Disciples AI")
26
-
27
- @app.get("/")
28
- async def root():
29
- return {"status": "✅ Tech Disciples AI Backend is active and ready."}
30
-
31
- # --------------------------------------------------------------
32
- # Authentication
33
- # --------------------------------------------------------------
34
- PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "techdisciplesai404")
35
-
36
- def check_auth(authorization: str | None):
37
- if not PROJECT_API_KEY:
38
- return
39
- if not authorization or not authorization.startswith("Bearer "):
40
- raise HTTPException(status_code=401, detail="Missing bearer token")
41
- token = authorization.split(" ", 1)[1]
42
- if token != PROJECT_API_KEY:
43
- raise HTTPException(status_code=403, detail="Invalid token")
44
-
45
- # --------------------------------------------------------------
46
- # Error Handler
47
- # --------------------------------------------------------------
48
- @app.exception_handler(Exception)
49
- async def global_exception_handler(request, exc):
50
- logger.exception("Unhandled error:")
51
- return JSONResponse(status_code=500, content={"error": str(exc)})
52
-
53
- # --------------------------------------------------------------
54
- # Request Model
55
- # --------------------------------------------------------------
56
- class ChatRequest(BaseModel):
57
- query: str
58
-
59
- # --------------------------------------------------------------
60
- # System Prompt
61
- # --------------------------------------------------------------
62
- SYSTEM_TEMPLATE = """You are Tech Disciples AI — a spiritually aware, intelligent, and kind conversational assistant.
63
- You respond clearly and truthfully, offering thoughtful, biblically grounded, and intelligent answers.
64
- Avoid debates or judgmental tones, and always reply in full sentences with understanding and grace.
65
- """
66
 
67
- # --------------------------------------------------------------
68
- # Model Load
69
- # --------------------------------------------------------------
70
  MODEL_NAME = "google/flan-t5-large"
71
- HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
72
 
73
- device = 0 if torch.cuda.is_available() else -1
74
- logger.info(f"🧠 Using device: {'GPU' if device == 0 else 'CPU'}")
 
75
 
76
- chat_pipe = None
 
 
 
77
  try:
78
  logger.info(f"🚀 Loading model: {MODEL_NAME}")
79
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN if HF_TOKEN else None)
80
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, token=HF_TOKEN if HF_TOKEN else None)
81
-
82
- chat_pipe = pipeline(
83
  "text2text-generation",
84
- model=model,
85
- tokenizer=tokenizer,
86
- device=device
 
 
87
  )
88
- logger.info("✅ Model pipeline initialized successfully.")
 
 
89
  except Exception as e:
90
- chat_pipe = None
91
- logger.error(f"❌ Failed to initialize model pipeline: {e}")
92
 
93
- # --------------------------------------------------------------
94
- # Helper: Generate Reply
95
- # --------------------------------------------------------------
96
- def generate_reply(pipe, user_input: str) -> str:
97
- if not pipe:
98
- return "⚠️ Model pipeline not initialized."
99
 
100
- prompt = f"{SYSTEM_TEMPLATE}\nUser: {user_input}\nTech Disciples AI:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  try:
103
- result = pipe(
104
- prompt,
105
- do_sample=True,
106
- temperature=0.3,
107
- top_p=0.9,
108
- repetition_penalty=1.05
109
- )
110
-
111
- if not result or "generated_text" not in result[0]:
112
- return "⚠️ Model returned no content."
113
-
114
- text = result[0]["generated_text"]
115
- if text.startswith(prompt):
116
- text = text[len(prompt):]
117
- text = text.replace("User:", "").replace("Tech Disciples AI:", "").strip()
118
-
119
- return text if text else "⚠️ No valid reply generated."
120
  except Exception as e:
121
- logger.exception("Error generating reply:")
122
- return f"⚠️ Model error: {e}"
123
 
124
- # --------------------------------------------------------------
125
- # Endpoint
126
- # --------------------------------------------------------------
127
- @app.post("/ai-chat")
128
- async def ai_chat(req: ChatRequest, authorization: str | None = Header(None)):
129
- check_auth(authorization)
130
- reply = generate_reply(chat_pipe, req.query)
131
- return {"reply": reply}
132
-
133
- # ==============================================================
134
- # END OF FILE
135
- # ==============================================================
 
1
+ from fastapi import FastAPI, HTTPException, Header
 
 
 
 
 
 
 
 
2
  from pydantic import BaseModel
3
+ from transformers import pipeline
4
+ from langchain.llms import HuggingFacePipeline
5
+ from langchain.chains import LLMChain
6
+ from langchain.prompts import PromptTemplate
7
+ import torch
8
+ import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # ===== CONFIG =====
11
+ API_SECRET = "techdisciplesai404"
 
12
  MODEL_NAME = "google/flan-t5-large"
13
+ DEVICE = 0 if torch.cuda.is_available() else -1
14
 
15
+ # ===== LOGGING =====
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger("TechDisciplesAI")
18
 
19
+ # ===== INITIALIZE APP =====
20
+ app = FastAPI(title="TechDisciples AI", version="2.0")
21
+
22
# ===== MODEL SETUP =====
# Load the Flan-T5 text2text pipeline once at import time and wrap it for
# LangChain. On any failure `llm` stays None and the /ai-chat route reports
# a 500 instead of the process crashing here.
llm = None  # bound up front so the name exists even if loading raises early
try:
    logger.info(f"🚀 Loading model: {MODEL_NAME}")
    pipe = pipeline(
        "text2text-generation",
        model=MODEL_NAME,
        device=DEVICE,
        max_new_tokens=256,
        temperature=0.3,
        do_sample=True,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    logger.info("✅ Model loaded successfully.")
except Exception as e:
    # logger.exception keeps the traceback that logger.error(f"...") discarded,
    # which is the only way to diagnose a failed model download/load in logs.
    logger.exception(f"❌ Failed to load model: {e}")
    llm = None
39
 
40
# ===== PROMPT TEMPLATE =====
# {query} is the single template variable filled in per request.
prompt_template = """
You are a Christian conversational AI named TechDisciples AI.
Answer the question naturally and clearly, providing biblical or inspirational insight where possible.

Question: {query}

Response:
"""

prompt = PromptTemplate(template=prompt_template, input_variables=["query"])

# Bug fix: LLMChain validates its `llm` field, so constructing it with
# llm=None (i.e. the model failed to load above) raised at import time and
# took down the entire app — including the "/" health route. Guard the
# construction so the service still starts; /ai-chat already returns a clean
# 500 when the model is unavailable.
chain = LLMChain(prompt=prompt, llm=llm) if llm is not None else None
52
+
53
# ===== REQUEST MODEL =====
class QueryInput(BaseModel):
    """Request body for POST /ai-chat."""

    # The user's question; must be non-empty after stripping (enforced in ai_chat).
    query: str
56
+
57
+
58
# ===== ROUTES =====
@app.post("/ai-chat")
async def ai_chat(data: QueryInput, x_api_key: str = Header(None)):
    """Answer a user query with the Flan-T5 LangChain pipeline.

    Auth: caller must supply the shared secret in the ``x-api-key`` header.
    Returns ``{"reply": <text>, "tone_used": "neutral"}``.
    Raises HTTP 403 (bad/missing key), 400 (empty query), 500 (model down
    or generation failure).
    """
    # Timing-safe comparison for the untrusted header; the `not x_api_key`
    # guard also keeps a missing header (None) from raising TypeError inside
    # compare_digest.
    if not x_api_key or not secrets.compare_digest(x_api_key, API_SECRET):
        raise HTTPException(status_code=403, detail="Forbidden: Invalid API key")

    # Model failed to load at startup; nothing to run the query against.
    if llm is None or chain is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    user_query = data.query.strip()
    if not user_query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    try:
        response = chain.run(query=user_query)
        return {"reply": response.strip(), "tone_used": "neutral"}
    except Exception as e:
        # logger.exception preserves the traceback for debugging generation errors.
        logger.exception(f"⚠️ Generation error: {e}")
        raise HTTPException(status_code=500, detail="Model failed to respond")
77
 
78
+
79
@app.get("/")
async def root():
    """Liveness endpoint: confirms the service process is up and serving."""
    status_payload = {"message": "✅ TechDisciples AI (LangChain) is running."}
    return status_payload