alaselababatunde committed on
Commit
0d65d98
·
1 Parent(s): b353ea6
Files changed (2) hide show
  1. main.py +34 -10
  2. requirements.txt +1 -2
main.py CHANGED
@@ -1,18 +1,22 @@
 
 
 
 
1
  from fastapi import FastAPI, HTTPException, Header
2
  from pydantic import BaseModel
3
  import torch
4
  import logging
5
 
6
- # LangChain imports (stable >=1.0)
7
- from langchain.chains import LLMChain
8
  from langchain.prompts import PromptTemplate
9
  from langchain.memory import ConversationBufferMemory
10
- from langchain.llms.base import LLM # For custom LLM wrappers
 
 
11
 
12
  # Transformers pipeline
13
  from transformers import pipeline
14
 
15
-
16
  # ===============================================
17
  # CONFIGURATION
18
  # ===============================================
@@ -32,7 +36,7 @@ logger = logging.getLogger("TechDisciplesAI")
32
  app = FastAPI(title="Tech Disciples AI (LangChain Conversational)", version="3.0")
33
 
34
  # ===============================================
35
- # LOAD MODEL USING PIPELINE + LANGCHAIN
36
  # ===============================================
37
  try:
38
  logger.info(f"🚀 Loading model: {MODEL_NAME}")
@@ -47,13 +51,31 @@ try:
47
  top_p=0.9
48
  )
49
 
50
- # You can later wrap hf_pipeline in a custom LLM class compatible with LLMChain
51
- llm = hf_pipeline
52
- logger.info("✅ Model loaded successfully.")
53
 
54
  except Exception as e:
55
- logger.error(f"❌ Failed to load model: {e}")
56
- llm = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # ===============================================
59
  # MEMORY SYSTEM
@@ -104,12 +126,14 @@ async def root():
104
 
105
  @app.post("/ai-chat")
106
  async def ai_chat(data: QueryInput, x_api_key: str = Header(None)):
 
107
  if x_api_key != API_SECRET:
108
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key")
109
 
110
  if not llm:
111
  raise HTTPException(status_code=500, detail="Model not initialized")
112
 
 
113
  try:
114
  response = chain.run(query=data.query.strip())
115
  return {"reply": response.strip()}
 
1
+ # ===============================================
2
+ # Tech Disciples AI Backend — Main.py
3
+ # ===============================================
4
+
5
  from fastapi import FastAPI, HTTPException, Header
6
  from pydantic import BaseModel
7
  import torch
8
  import logging
9
 
10
+ # LangChain 1.0 imports
 
11
  from langchain.prompts import PromptTemplate
12
  from langchain.memory import ConversationBufferMemory
13
+ from langchain.chains import LLMChain
14
+ from langchain.llms.base import LLM
15
+ from typing import Optional, List
16
 
17
  # Transformers pipeline
18
  from transformers import pipeline
19
 
 
20
  # ===============================================
21
  # CONFIGURATION
22
  # ===============================================
 
36
  app = FastAPI(title="Tech Disciples AI (LangChain Conversational)", version="3.0")
37
 
38
  # ===============================================
39
+ # HUGGING FACE PIPELINE
40
  # ===============================================
41
  try:
42
  logger.info(f"🚀 Loading model: {MODEL_NAME}")
 
51
  top_p=0.9
52
  )
53
 
54
+ logger.info("✅ Hugging Face pipeline loaded successfully.")
 
 
55
 
56
  except Exception as e:
57
+ logger.error(f"❌ Failed to load Hugging Face pipeline: {e}")
58
+ hf_pipeline = None
59
+
60
# ===============================================
# HUGGING FACE LLM WRAPPER FOR LANGCHAIN
# ===============================================
from typing import Any  # needed for the Pydantic field declaration below


class HFLLMWrapper(LLM):
    """Minimal LangChain LLM adapter around a Hugging Face text-generation pipeline.

    LangChain's ``LLM`` base class is a Pydantic model, so instance state must
    be declared as a model field. The previous implementation assigned
    ``self.pipeline`` inside a plain ``__init__``, which Pydantic rejects at
    construction time ("object has no field 'pipeline'") — the wrapper could
    never be instantiated.
    """

    # The wrapped transformers pipeline, declared as a Pydantic field so
    # assignment is legal on the Pydantic-based LLM class.
    pipeline: Any

    def __init__(self, pipeline):
        # Preserve the original positional call shape: HFLLMWrapper(hf_pipeline).
        super().__init__(pipeline=pipeline)

    @property
    def _llm_type(self) -> str:
        # Identifier LangChain uses when serializing / logging this LLM.
        return "hf_pipeline"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the pipeline on *prompt* and return the generated text.

        `stop` is accepted for LangChain interface compatibility but is not
        forwarded to the pipeline — NOTE(review): confirm whether stop
        sequences should be honored by the underlying model call.
        """
        output = self.pipeline(prompt)
        # transformers text-generation pipelines return a list like
        # [{"generated_text": "..."}]; fall back to str() for anything else.
        if isinstance(output, list) and len(output) > 0:
            return output[0].get("generated_text", str(output[0]))
        return str(output)


# Initialize LLM wrapper; stays None when the pipeline failed to load so the
# /ai-chat endpoint can return a 500 instead of crashing.
llm = HFLLMWrapper(hf_pipeline) if hf_pipeline else None
79
 
80
  # ===============================================
81
  # MEMORY SYSTEM
 
126
 
127
  @app.post("/ai-chat")
128
  async def ai_chat(data: QueryInput, x_api_key: str = Header(None)):
129
+ # --- Authentication ---
130
  if x_api_key != API_SECRET:
131
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key")
132
 
133
  if not llm:
134
  raise HTTPException(status_code=500, detail="Model not initialized")
135
 
136
+ # --- Process Query ---
137
  try:
138
  response = chain.run(query=data.query.strip())
139
  return {"reply": response.strip()}
requirements.txt CHANGED
@@ -3,7 +3,6 @@ uvicorn[standard]
3
  torch
4
  transformers
5
  accelerate
6
- langchain==1.0.0
7
  huggingface-hub
8
  pydantic
9
- python-multipart
 
3
  torch
4
  transformers
5
  accelerate
6
+ langchain>=1.0
7
  huggingface-hub
8
  pydantic