Sandei committed on
Commit
81b1a96
·
1 Parent(s): c4e54a2

Deploy FastAPI app

Browse files
Files changed (9) hide show
  1. .env +1 -0
  2. Dockerfile +20 -0
  3. __pycache__/main.cpython-314.pyc +0 -0
  4. app.py +107 -0
  5. final_data_set(in).csv +0 -0
  6. memeory.py +17 -0
  7. models.py +31 -0
  8. rag.py +18 -0
  9. requirements.txt +12 -0
.env ADDED
@@ -0,0 +1 @@
 
 
1
+ GEMINI_API_KEY=
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Copy only the dependency manifest first so the (slow) install layers below
# are cached and rebuilt only when requirements change — the original copied
# the whole source tree first, invalidating the cache on every code edit.
COPY requirements.txt /app/requirements.txt

RUN pip install --no-cache-dir --upgrade pip

# Install CPU-only PyTorch (the generic PyPI wheel would pull in CUDA).
# NOTE(review): requirements.txt also lists an unpinned `torch`; keep these
# pins in sync or the next install step may replace this CPU build.
RUN pip install --no-cache-dir \
    torch==2.1.2+cpu \
    torchvision==0.16.2+cpu \
    torchaudio==2.1.2+cpu \
    --index-url https://download.pytorch.org/whl/cpu

RUN pip install --no-cache-dir -r requirements.txt

# Copy the application source last — changes here no longer trigger reinstalls.
COPY . /app

EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
__pycache__/main.cpython-314.pyc ADDED
Binary file (4.43 kB). View file
 
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig

from models import (
    QueryRequest,
    QueryResponse,
    CategoryPrediction,
    UrgencyPrediction
)
from rag import generate_answer
# FIX: the module committed in this repo is named "memeory.py" (note the typo),
# so the original `from memory import ...` raised ModuleNotFoundError at
# startup. Import the actual module name; TODO: rename the file to memory.py
# and revert this.
from memeory import get_conversation, add_message

# Use GPU when available; the Docker image ships CPU-only torch, so inside
# the container this resolves to "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): placeholder model id — must point at the real fine-tuned
# multi-task checkpoint before deployment.
CLASSIFIER_MODEL_ID = "your-org/your-multitask-model"

# Multi-label ticket categories, index-aligned with the model's category head.
tag_classes = [
    "Billing",
    "Network & Connectivity",
    "Account Access",
    "Hardware",
    "Other"
]

# Maps the urgency head's class index to a human-readable label.
urgency_encoder = {
    0: "low",
    1: "medium",
    2: "high"
}

tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_ID)
config = AutoConfig.from_pretrained(CLASSIFIER_MODEL_ID)

# trust_remote_code=True because the checkpoint is expected to ship a custom
# multi-head architecture (outputs expose category_logits and urgency_logits,
# see classify_text) — confirm the source of that remote code is trusted.
model = AutoModelForSequenceClassification.from_pretrained(
    CLASSIFIER_MODEL_ID,
    config=config,
    trust_remote_code=True
).to(DEVICE)

model.eval()  # inference only

app = FastAPI(title="RAG + Conversation Memory API")
45
# ---------------------
# CLASSIFIER
# ---------------------
def classify_text(text: str, threshold: float = 0.5):
    """Run the multi-task classifier on *text*.

    Returns ``(categories, urgency)``: *categories* is the list of
    CategoryPrediction whose sigmoid score reaches *threshold* (may be
    empty), *urgency* is the single most likely UrgencyPrediction.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)

    with torch.no_grad():
        result = model(**encoded)

    # Multi-label head: one independent sigmoid score per category.
    cat_scores = torch.sigmoid(result.category_logits)[0].cpu().numpy()
    categories = []
    for label, score in zip(tag_classes, cat_scores):
        if score >= threshold:
            categories.append(
                CategoryPrediction(category=label, confidence=float(score))
            )

    # Single-label head: softmax over the urgency levels, argmax picks one.
    probs = torch.softmax(result.urgency_logits, dim=-1)[0].cpu().numpy()
    best = int(torch.argmax(result.urgency_logits, dim=-1)[0])
    urgency = UrgencyPrediction(
        label=urgency_encoder[best],
        confidence=float(probs[best])
    )

    return categories, urgency
74
+
75
+
76
def retrieve_documents(query: str):
    """Stub retriever: returns a fixed set of knowledge-base snippets.

    *query* is currently ignored. TODO: replace with real vector search
    (faiss-cpu is already in requirements.txt).
    """
    snippets = (
        "Restarting the router fixes most connectivity issues.",
        "Check for planned ISP maintenance.",
        "Verify cables are securely connected.",
    )
    return list(snippets)
82
+
83
+
84
@app.post("/query", response_model=QueryResponse)
def query_endpoint(req: QueryRequest):
    """Classify the query, generate a RAG answer, and persist the exchange."""
    # Prior turns for this user (empty on first contact).
    history = get_conversation(req.user_id)

    # Category + urgency prediction from the multi-task classifier.
    categories, urgency = classify_text(req.query)

    # Retrieval-augmented answer, conditioned on the conversation so far.
    documents = retrieve_documents(req.query)
    answer = generate_answer(req.query, documents, history)

    # Record both sides of the exchange before responding, so the returned
    # conversation already includes this turn.
    add_message(req.user_id, "user", req.query)
    add_message(req.user_id, "assistant", answer)

    return QueryResponse(
        user_id=req.user_id,
        query=req.query,
        answer=answer,
        categories=categories,
        urgency=urgency,
        conversation=get_conversation(req.user_id),
    )
final_data_set(in).csv ADDED
The diff for this file is too large to render. See raw diff
 
memeory.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from collections import defaultdict, deque

MAX_TURNS = 6  # last N messages per user

# Per-user rolling buffer: the bounded deque silently drops the oldest
# message once MAX_TURNS entries are stored. In-process only — memory is
# lost on restart and not shared across workers.
conversation_store = defaultdict(
    lambda: deque(maxlen=MAX_TURNS)
)


def get_conversation(user_id: str):
    """Return the user's recent messages as a plain list, oldest first."""
    return [message for message in conversation_store[user_id]]


def add_message(user_id: str, role: str, content: str):
    """Append one ``{"role", "content"}`` message to the user's buffer."""
    entry = {"role": role, "content": content}
    conversation_store[user_id].append(entry)
models.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pydantic import BaseModel
from typing import List


class QueryRequest(BaseModel):
    """Request body for POST /query."""
    user_id: str  # key used for per-user conversation memory
    query: str  # free-text support question


class CategoryPrediction(BaseModel):
    """One multi-label category hit from the classifier."""
    category: str  # one of the tag_classes labels defined in app.py
    confidence: float  # sigmoid score in [0, 1]


class UrgencyPrediction(BaseModel):
    """Single-label urgency prediction."""
    label: str  # "low" | "medium" | "high" (see urgency_encoder in app.py)
    confidence: float  # softmax probability of the chosen label


class Message(BaseModel):
    """One conversation turn."""
    role: str  # "user" or "assistant"
    content: str


class QueryResponse(BaseModel):
    """Response body for POST /query."""
    user_id: str
    query: str
    answer: str
    categories: List[CategoryPrediction]  # empty if nothing clears the threshold
    urgency: UrgencyPrediction
    conversation: List[Message]  # memory after this turn was recorded
rag.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def generate_answer(query: str, retrieved_docs: list[str], history: list[dict]) -> str:
    """Build a canned support reply from the query, retrieved docs, and history.

    TODO: this is a template stub — swap in a real LLM call
    (a GEMINI_API_KEY slot is provisioned in .env).
    """
    turns = [f"{m['role']}: {m['content']}" for m in history]
    history_text = "\n".join(turns)

    # Only the top three snippets make it into the prompt context.
    context = "\n".join(retrieved_docs[:3])

    reply = f"""
Conversation so far:
{history_text}

Knowledge base:
{context}

Answer:
We have received your request regarding "{query}".
Our support team will assist you shortly.
"""
    return reply.strip()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Web framework / ASGI server
fastapi>=0.103,<1
uvicorn[standard]

# ML stack
# NOTE(review): torch is unpinned here, but the Dockerfile installs
# torch==2.1.2+cpu from the CPU wheel index — keep the two in sync so this
# install step does not replace the CPU build with a CUDA wheel.
torch
transformers>=4.36,<5
sentence-transformers>=2.2,<3
huggingface-hub>=0.20,<1
accelerate

# Retrieval / data / config
faiss-cpu
pandas
python-dotenv