Shriyakupp committed on
Commit
78b6072
·
verified ·
1 Parent(s): 66c419a

Create api.py

Browse files
Files changed (1) hide show
  1. api.py +152 -0
api.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os
from collections import defaultdict
from typing import Optional

import faiss
import numpy as np
import requests
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from langchain_nomic import NomicEmbeddings  # ✅ Updated Embedding Model
from pydantic import BaseModel
12
+
13
# --- Environment ---
# Let python-dotenv populate the process environment before anything reads it.
load_dotenv()

# --- FastAPI application ---
# The route decorator further down attaches its handler to this instance.
app = FastAPI()

# Key used as the bearer token for AIPIPE requests; importing the module
# fails fast when it is absent or empty.
api_key = os.environ.get("AIPIPE_API_KEY")
if api_key is None or api_key == "":
    raise RuntimeError("Missing API key in environment variables.")
22
+
23
# --- Load Discourse Data ---
# The post dump is read once at import time. A missing file and a malformed
# file both surface as RuntimeError (chained to the underlying error) so that
# every startup failure in this module has the same shape.
try:
    with open("data/discourse_posts.json", "r", encoding="utf-8") as f:
        posts_data = json.load(f)
except FileNotFoundError as err:
    raise RuntimeError("Could not find 'data/discourse_posts.json'. Ensure the file is in the correct location.") from err
except json.JSONDecodeError as err:
    raise RuntimeError(f"'data/discourse_posts.json' is not valid JSON: {err}") from err

# Group posts by topic: topic_id -> {"topic_title": str, "posts": [post, ...]}
topics = defaultdict(lambda: {"topic_title": "", "posts": []})
for post in posts_data:
    tid = post["topic_id"]
    topics[tid]["posts"].append(post)
    if "topic_title" in post:
        # When several posts carry a title, the last one seen is kept.
        topics[tid]["topic_title"] = post["topic_title"]

# Sort posts within topics by post_number. `or 0` also covers an explicit
# null post_number, which would otherwise make the comparison raise TypeError.
for topic in topics.values():
    topic["posts"].sort(key=lambda x: x.get("post_number") or 0)
41
+
42
+ # --- Embedding Setup ---
43
def normalize(v):
    """Return *v* scaled to unit Euclidean length, always as a NumPy array.

    Args:
        v: A 1-D sequence or array of numbers (for example the list of
            floats returned by an embedding call).

    Returns:
        A float64 ``numpy.ndarray``. A zero vector is returned unchanged
        (as an array) because it has no direction to normalise.
    """
    # Convert up front so the zero-norm branch also returns an ndarray;
    # callers invoke ndarray methods such as ``.astype`` on the result.
    vec = np.asarray(v, dtype=np.float64)
    norm = np.linalg.norm(vec)
    if norm == 0:
        return vec
    return vec / norm
46
+
47
embedder = NomicEmbeddings(model="nomic-embed-text")  # ✅ Updated Embedding Model


def _collect_thread(root_number, by_number, reply_map):
    """Return the post numbered *root_number* followed by all of its replies.

    Posts come back in depth-first pre-order, with siblings in the order they
    appear in ``reply_map``. The walk is iterative and tracks visited post
    numbers, so malformed data (duplicate post numbers, reply loops) cannot
    cause unbounded recursion.

    Args:
        root_number: post_number at which the thread starts.
        by_number: Mapping of post_number -> post dict for one topic.
        reply_map: Mapping of parent post_number -> list of reply post dicts.

    Returns:
        A list of post dicts; empty when *root_number* is not in *by_number*.
    """
    collected = []
    seen = set()
    stack = [root_number]
    while stack:
        number = stack.pop()
        if number in seen or number not in by_number:
            continue
        seen.add(number)
        collected.append(by_number[number])
        # Push children reversed so the first reply is popped (visited) first.
        stack.extend(
            child.get("post_number")
            for child in reversed(reply_map.get(number, []))
        )
    return collected


def _build_thread_embeddings(topics_by_id):
    """Embed every reply thread of every topic.

    A thread starts at a post that either has no ``reply_to_post_number`` or
    whose parent is missing from the topic's data, so replies to absent posts
    are still indexed instead of being dropped.

    Args:
        topics_by_id: Mapping of topic_id ->
            ``{"topic_title": str, "posts": [post, ...]}``.

    Returns:
        ``(records, vectors)``: two parallel lists, one metadata dict and one
        normalised embedding per thread.
    """
    records = []
    vectors = []
    for tid, data in topics_by_id.items():
        posts = data["posts"]
        title = data["topic_title"]

        by_number = {}
        reply_map = defaultdict(list)
        for p in posts:
            pn = p.get("post_number")
            if pn is not None:
                by_number[pn] = p
            reply_map[p.get("reply_to_post_number")].append(p)

        for root in posts:
            parent = root.get("reply_to_post_number")
            if parent and parent in by_number:
                # A reply whose parent is present is reached via that parent.
                continue
            root_num = root.get("post_number", 1)
            thread = _collect_thread(root_num, by_number, reply_map)
            if not thread:
                # Nothing resolvable under this root; do not embed a bare title.
                continue
            text = f"Topic: {title}\n\n" + "\n\n---\n\n".join(
                p.get("content", "").strip() for p in thread if p.get("content")
            )
            # Corpus text goes through the document-side embedding call; the
            # query-side call is reserved for user questions.
            # NOTE(review): langchain_nomic is understood to tag the two calls
            # with different Nomic task types - confirm against its docs.
            emb = normalize(embedder.embed_documents([text])[0])
            records.append({
                "topic_id": tid,
                "topic_title": title,
                "root_post_number": root_num,
                "post_numbers": [p.get("post_number") for p in thread],
                "combined_text": text,
            })
            vectors.append(emb)
    return records, vectors


# Process topics for FAISS: embedding_data[i] describes the thread whose
# vector is embeddings[i].
embedding_data, embeddings = _build_thread_embeddings(topics)
93
+
94
# Create FAISS index
# Fail with a clear message instead of an IndexError on embeddings[0] when
# no thread was embedded.
if not embeddings:
    raise RuntimeError("No embeddings were produced; cannot build the FAISS index.")

# One row per thread. The inner-product index is fed the vectors produced by
# normalize() above.
_embedding_matrix = np.vstack(embeddings).astype("float32")
index = faiss.IndexFlatIP(_embedding_matrix.shape[1])
index.add(_embedding_matrix)
97
+
98
+ # --- API Input Model ---
99
class QuestionInput(BaseModel):
    """Request body accepted by the question-answering endpoint."""

    # The user's question text (required).
    question: str
    # Optional image input, unused here. Declared Optional so that both an
    # omitted field and an explicit JSON null are accepted.
    image: Optional[str] = None
102
+
103
# --- AIPIPE API Configuration ---
# The endpoint can be overridden with the AIPIPE_URL environment variable.
# NOTE(review): the fallback looks like a placeholder host - confirm the real
# endpoint is supplied via the environment before deploying.
AIPIPE_URL = os.getenv("AIPIPE_URL", "https://your-aipipe-endpoint.com/chat/completions")
# Bearer token sent with every AIPIPE request.
AIPIPE_KEY = api_key
106
+
107
def query_aipipe(prompt, timeout=30):
    """Send *prompt* as a single user message to the AIPIPE chat endpoint.

    Args:
        prompt: Text placed in the one user message of the request.
        timeout: Seconds passed to ``requests.post`` so that a stalled
            upstream cannot block the caller indefinitely.

    Returns:
        The decoded JSON body of a 200 response.

    Raises:
        HTTPException: status 500 when the request cannot be completed, the
            upstream replies with a non-200 status, or the body is not JSON.
    """
    headers = {"Authorization": f"Bearer {AIPIPE_KEY}", "Content-Type": "application/json"}
    data = {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": prompt}], "temperature": 0.7}

    try:
        response = requests.post(AIPIPE_URL, json=data, headers=headers, timeout=timeout)
    except requests.RequestException as exc:
        # Connection failures and timeouts are reported the same way as
        # upstream error statuses.
        raise HTTPException(status_code=500, detail=f"AIPIPE API error: {exc}") from exc

    if response.status_code != 200:
        raise HTTPException(status_code=500, detail=f"AIPIPE API error: {response.text}")

    try:
        return response.json()
    except ValueError as exc:
        # A 200 response whose body is not JSON.
        raise HTTPException(status_code=500, detail=f"AIPIPE API error: {response.text}") from exc
116
+
117
+ # --- API Endpoint for Answer Generation ---
118
@app.post("/api/")
async def answer_question(payload: QuestionInput):
    """Answer a question, using the closest indexed threads as context.

    The question is embedded, the three nearest threads are looked up in the
    FAISS index, their text is placed in the prompt sent to AIPIPE, and the
    model's reply is returned together with links to those threads.

    Args:
        payload: Request body; only ``question`` is used.

    Returns:
        ``{"answer": str, "links": [{"url": str, "text": str}, ...]}``.

    Raises:
        HTTPException: 400 for an empty or whitespace-only question; 500 when
            the AIPIPE call fails or its response cannot be interpreted.
    """
    # NOTE(review): embed_query and the requests call inside query_aipipe are
    # synchronous, so this async handler blocks while they run.
    q = payload.question

    # Ensure query is valid (whitespace-only counts as empty)
    if not q or not q.strip():
        raise HTTPException(status_code=400, detail="Question field cannot be empty.")

    # Search FAISS Index. np.asarray also copes with normalize() handing back
    # a plain list.
    q_emb = np.asarray(normalize(embedder.embed_query(q)), dtype="float32")
    scores, indices = index.search(q_emb.reshape(1, -1), 3)

    top_results = []
    for score, idx in zip(scores[0], indices[0]):
        # FAISS pads missing neighbours with -1, which as a Python list
        # index would silently select the last record.
        if idx < 0:
            continue
        data = embedding_data[idx]
        top_results.append({
            "score": float(score),
            "text": data["combined_text"],
            "topic_id": data["topic_id"],
            "url": f"https://discourse.onlinedegree.iitm.ac.in/t/{data['topic_id']}"
        })

    # Build the prompt from the retrieved threads so the answer is grounded
    # in them. Each excerpt is capped to keep the request size bounded.
    max_excerpt_chars = 2000
    context = "\n\n---\n\n".join(r["text"][:max_excerpt_chars] for r in top_results)
    if context:
        prompt = (
            "Answer the question using the forum excerpts below when they are relevant.\n\n"
            f"Excerpts:\n{context}\n\n"
            f"Question: {q}"
        )
    else:
        prompt = q

    # Generate answer using AIPIPE
    try:
        answer_response = query_aipipe(prompt)
        # Tolerate a missing or empty "choices" list and a null message/content.
        choices = answer_response.get("choices") or [{}]
        message = choices[0].get("message") or {}
        answer = message.get("content") or "No response."
    except HTTPException:
        # Already a well-formed API error from query_aipipe; do not re-wrap.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching response from AIPIPE: {str(e)}") from e

    links = [{"url": r["url"], "text": r["text"][:120]} for r in top_results]
    return {"answer": answer, "links": links}
149
+
150
+ # --- Run the Server ---
151
def _serve():
    """Start uvicorn on the ``api:app`` import string with reload enabled."""
    uvicorn.run("api:app", reload=True)


# Only start the server when this file is executed as a script.
if __name__ == "__main__":
    _serve()