alaselababatunde committed on
Commit
21cb694
·
1 Parent(s): bec61fa
Files changed (1) hide show
  1. main.py +44 -43
main.py CHANGED
@@ -8,23 +8,19 @@ import torch
8
  import logging
9
  import os
10
 
11
- # Hugging Face Hub
12
  from huggingface_hub import login
13
-
14
- # LangChain
15
  from langchain.llms.huggingface_pipeline import HuggingFacePipeline
16
  from langchain.chains import LLMChain
17
  from langchain.prompts.prompt import PromptTemplate
18
  from langchain.memory import ConversationBufferMemory
19
-
20
- # Transformers pipeline
21
  from transformers import pipeline
22
 
23
  # =====================================================
24
  # CONFIGURATION
25
  # =====================================================
26
  API_SECRET = "techdisciplesai404"
27
- MODEL_NAME = "meta-llama/Llama-3.1-8B"
 
28
  DEVICE = 0 if torch.cuda.is_available() else -1
29
 
30
  # =====================================================
@@ -39,39 +35,47 @@ logger = logging.getLogger("TechDisciplesAI")
39
  app = FastAPI(title="Tech Disciples AI", version="3.1")
40
 
41
  # =====================================================
42
- # MODEL LOAD
43
  # =====================================================
44
- llm = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- try:
47
- logger.info(f"🚀 Loading model: {MODEL_NAME}")
 
 
48
 
49
- hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
50
- if hf_token:
51
  login(token=hf_token)
52
- logger.info("🔐 Hugging Face authentication successful.")
53
- else:
54
- logger.warning("⚠️ HUGGINGFACEHUB_API_TOKEN not found gated models may fail.")
55
-
56
- # Load text generation pipeline
57
- hf_pipeline = pipeline(
58
- "text-generation",
59
- model=MODEL_NAME,
60
- device=DEVICE,
61
- max_new_tokens=1024,
62
- temperature=0.4,
63
- top_p=0.9,
64
- repetition_penalty=1.15,
65
- do_sample=True,
66
- use_auth_token=True
67
- )
68
-
69
- llm = HuggingFacePipeline(pipeline=hf_pipeline)
70
- logger.info("✅ Model loaded successfully (Llama 3.1 - 8B).")
71
-
72
- except Exception as e:
73
- logger.error(f"❌ Model load failed: {e}")
74
- llm = None
75
 
76
  # =====================================================
77
  # MEMORY + PROMPT
@@ -96,10 +100,7 @@ prompt = PromptTemplate(
96
  input_variables=["conversation_history", "query"]
97
  )
98
 
99
- if llm:
100
- chain = LLMChain(prompt=prompt, llm=llm, memory=memory)
101
- else:
102
- chain = None
103
 
104
  # =====================================================
105
  # REQUEST MODEL
@@ -113,7 +114,7 @@ class QueryInput(BaseModel):
113
  # =====================================================
114
  @app.get("/")
115
  async def root():
116
- return {"message": "✅ Tech Disciples AI is running."}
117
 
118
  @app.post("/ai-chat")
119
  async def ai_chat(data: QueryInput, x_api_key: str = Header(None)):
@@ -121,11 +122,11 @@ async def ai_chat(data: QueryInput, x_api_key: str = Header(None)):
121
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key")
122
 
123
  if not chain:
124
- raise HTTPException(status_code=500, detail="Model not initialized or failed to load")
125
 
126
  try:
127
  response = chain.run(query=data.query.strip())
128
  return {"reply": response.strip()}
129
  except Exception as e:
130
- logger.error(f"⚠️ Error generating response: {e}")
131
- raise HTTPException(status_code=500, detail="Model failed to respond")
 
8
  import logging
9
  import os
10
 
 
11
  from huggingface_hub import login
 
 
12
  from langchain.llms.huggingface_pipeline import HuggingFacePipeline
13
  from langchain.chains import LLMChain
14
  from langchain.prompts.prompt import PromptTemplate
15
  from langchain.memory import ConversationBufferMemory
 
 
16
  from transformers import pipeline
17
 
18
  # =====================================================
19
  # CONFIGURATION
20
  # =====================================================
21
# Shared secret checked against the `x-api-key` header on /ai-chat.
# NOTE(review): a hard-coded secret in source is a security risk — prefer the
# environment; the literal is kept only as a backward-compatible default.
API_SECRET = os.getenv("API_SECRET", "techdisciplesai404")
PRIMARY_MODEL = "meta-llama/Llama-3.1-8B"
FALLBACK_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# HF pipeline convention: device 0 = first CUDA GPU, -1 = CPU.
DEVICE = 0 if torch.cuda.is_available() else -1
25
 
26
  # =====================================================
 
35
  app = FastAPI(title="Tech Disciples AI", version="3.1")
36
 
37
  # =====================================================
38
+ # MODEL LOADING FUNCTION
39
  # =====================================================
40
def load_model(model_name, token=None):
    """Build a LangChain LLM backed by a HF text-generation pipeline.

    Returns a HuggingFacePipeline on success, or None when the model
    cannot be loaded (e.g. a gated repo without a valid token), so the
    caller can fall back to a different model instead of crashing.
    """
    try:
        logger.info(f"🚀 Attempting to load model: {model_name}")
        text_gen = pipeline(
            "text-generation",
            model=model_name,
            device=DEVICE,
            max_new_tokens=1024,
            temperature=0.4,
            top_p=0.9,
            repetition_penalty=1.15,
            do_sample=True,
            token=token,  # ✅ modern auth argument (replaces deprecated use_auth_token)
        )
        logger.info(f"✅ Loaded model successfully: {model_name}")
        return HuggingFacePipeline(pipeline=text_gen)
    except Exception as e:
        # Broad catch is deliberate: any load failure (auth, OOM, download)
        # should return None and trigger the caller's fallback model path.
        logger.error(f"❌ Failed to load {model_name}: {e}")
        return None
59
 
60
+ # =====================================================
61
+ # LOAD TOKEN + MODEL
62
+ # =====================================================
63
# -----------------------------------------------------------------
# Authenticate with the Hugging Face Hub (required for gated repos
# such as Llama 3.1), then load the primary model; on failure, fall
# back to the Mistral model. `llm` may still be None afterwards —
# downstream code must handle that.
# -----------------------------------------------------------------
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

if hf_token:
    try:
        login(token=hf_token)
        logger.info("🔐 Hugging Face token authenticated.")
    except Exception as e:
        # Login failure is non-fatal: public models can still load.
        logger.warning(f"⚠️ Failed to log in: {e}")
else:
    logger.warning("⚠️ No HUGGINGFACEHUB_API_TOKEN found.")

llm = load_model(PRIMARY_MODEL, token=hf_token)

if llm is None:
    logger.warning("⚠️ Falling back to Mistral 7B due to model load issue...")
    llm = load_model(FALLBACK_MODEL, token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  # =====================================================
81
  # MEMORY + PROMPT
 
100
  input_variables=["conversation_history", "query"]
101
  )
102
 
103
# Build the conversation chain only when a model actually loaded; the
# endpoints check `chain` for None and answer HTTP 500 otherwise.
chain = LLMChain(prompt=prompt, llm=llm, memory=memory) if llm else None
 
 
 
104
 
105
  # =====================================================
106
  # REQUEST MODEL
 
114
  # =====================================================
115
@app.get("/")
async def root():
    """Health-check endpoint: confirms the service is reachable."""
    return {"message": "✅ Tech Disciples AI is online."}
118
 
119
  @app.post("/ai-chat")
120
  async def ai_chat(data: QueryInput, x_api_key: str = Header(None)):
 
122
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key")
123
 
124
  if not chain:
125
+ raise HTTPException(status_code=500, detail="Model not initialized")
126
 
127
  try:
128
  response = chain.run(query=data.query.strip())
129
  return {"reply": response.strip()}
130
  except Exception as e:
131
+ logger.error(f"⚠️ Model runtime error: {e}")
132
+ raise HTTPException(status_code=500, detail=f"Model failed to respond — {e}")