Kalpokoch committed
Commit 982da50 · verified · 1 Parent(s): 2bf4b72

Update app/app.py

Files changed (1)
  1. app/app.py +40 -69
app/app.py CHANGED
@@ -1,96 +1,67 @@
- from fastapi import FastAPI
- from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel
  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
- from app.policy_vector_db import PolicyVectorDB  # Import your class

- # --- 1. Initialize the Vector Database and LLM ---

- # Load the vector database from /tmp (safest in Docker/HF Spaces)
  print("Loading Vector Database...")
  db = PolicyVectorDB(persist_directory="/tmp/policy_vector_db")
  print("Vector Database loaded successfully!")

  # Load your quantized model from Hugging Face Hub
- model_id = "Kalpokoch/QuantizedTinyLama"  # Correct spelling assumed
  print(f"Loading model: {model_id}...")
  tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(model_id)
- print("Model and tokenizer loaded successfully!")

- # Choose dtype depending on device support
- dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

  model = AutoModelForCausalLM.from_pretrained(
      model_id,
-     torch_dtype=dtype,
-     device_map="auto"
  )

- # Create a text-generation pipeline
- pipe = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     max_new_tokens=256
- )

- print("LLM and pipeline loaded successfully!")

- # --- 2. FastAPI App Setup ---
- app = FastAPI()

- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )

- @app.get("/")
- def read_root():
-     return {"message": "RAG chatbot backend is running with Kalpokoch/QuantizedTinyLlama and ChromaDB!"}

- class ChatRequest(BaseModel):
-     question: str

- @app.post("/chat")
- def chat(request: ChatRequest):
-     question = request.question.strip()
-     if not question:
-         return {"response": "Please ask a question."}
-
-     # --- 3. RAG Retrieval using PolicyVectorDB ---
-     print(f"Searching for context for question: '{question}'")
-     search_results = db.search(query_text=question, top_k=3)
-
-     if not search_results:
-         retrieved_context = "No relevant context found."
-     else:
-         retrieved_context = "\n\n".join([result['text'] for result in search_results])
-
-     print(f"Retrieved Context:\n{retrieved_context[:500]}...")
-
-     # --- 4. Prompt Engineering and Generation ---
-     prompt = (
-         f"<|system|>\nYou are a helpful assistant for NEEPCO policies. "
-         f"Use the following context to answer the user's question. If the context doesn't contain the answer, say that.\n"
-         f"Context:\n{retrieved_context}</s>\n"
-         f"<|user|>\n{question}</s>\n"
-         f"<|assistant|>"
-     )
-
-     try:
-         outputs = pipe(prompt)
-         reply = outputs[0]['generated_text']
-         assistant_reply = reply.split("<|assistant|>")[1].strip()
-         return {"response": assistant_reply}
-     except Exception as e:
-         print(f"Error during model inference: {e}")
-         return {"response": "Sorry, I encountered an error while generating a response."}

+ from fastapi import FastAPI, Request
  from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  import torch
+ from policy_vector_db import PolicyVectorDB  # Make sure this is your local DB logic
+ import chromadb

+ # Create FastAPI app
+ app = FastAPI()

+ # Load the vector database from /tmp (safe for Hugging Face Spaces)
  print("Loading Vector Database...")
  db = PolicyVectorDB(persist_directory="/tmp/policy_vector_db")
  print("Vector Database loaded successfully!")

  # Load your quantized model from Hugging Face Hub
+ model_id = "Kalpokoch/QuantizedTinyLama"
  print(f"Loading model: {model_id}...")
+
+ # Load tokenizer
  tokenizer = AutoTokenizer.from_pretrained(model_id)

+ # Quantization config for bitsandbytes
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )

+ # Load quantized model
  model = AutoModelForCausalLM.from_pretrained(
      model_id,
+     device_map="auto",
+     quantization_config=bnb_config
  )

+ print("Model and tokenizer loaded successfully!")

+ # Input schema
+ class Query(BaseModel):
+     question: str

+ # Define endpoint
+ @app.post("/chat/")
+ async def chat(query: Query):
+     question = query.question
+
+     # Step 1: Vector DB search
+     search_results = db.search(question)
+     context = "\n".join([res["content"] for res in search_results])
+
+     # Step 2: Build prompt
+     prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
+
+     # Step 3: Tokenize and generate
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
+
+     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Optionally strip out the prompt from the output
+     final_answer = answer.split("Answer:")[-1].strip()
+
+     return {"answer": final_answer}
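The updated handler only relies on PolicyVectorDB exposing a constructor that accepts persist_directory and a search() method whose results carry a "content" field. A minimal sketch of a Chroma-backed implementation that satisfies that contract is shown below; the collection name, the default embedding function, and the top_k default are illustrative assumptions, not taken from this repository's policy_vector_db.py.

# Sketch only: an assumed Chroma-backed PolicyVectorDB matching what app.py calls.
import chromadb

class PolicyVectorDB:
    def __init__(self, persist_directory: str):
        # Persistent Chroma client rooted at the given directory (e.g. /tmp/policy_vector_db).
        self.client = chromadb.PersistentClient(path=persist_directory)
        # "policies" is a placeholder collection name; relies on Chroma's default embedding function.
        self.collection = self.client.get_or_create_collection("policies")

    def search(self, query_text: str, top_k: int = 3):
        # Return the top_k most similar chunks as dicts with a "content" key,
        # which is the shape the /chat/ handler iterates over.
        results = self.collection.query(query_texts=[query_text], n_results=top_k)
        documents = results.get("documents") or [[]]
        return [{"content": doc} for doc in documents[0]]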
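Once the Space is running, the new /chat/ route can be exercised with a small client like the sketch below; the base URL and the sample question are placeholders, while the request body ("question") and response key ("answer") match the Query model and return value in the handler above.

import requests

# Placeholder URL; replace with the actual host or Space serving this FastAPI app.
BASE_URL = "http://localhost:7860"

payload = {"question": "What does the policy say about leave encashment?"}
response = requests.post(f"{BASE_URL}/chat/", json=payload, timeout=120)
response.raise_for_status()
print(response.json()["answer"])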