shreyankisiri commited on
Commit
5cd49b4
·
verified ·
1 Parent(s): 3106607

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +147 -0
main.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import chromadb
4
+ import numpy as np
5
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
6
+ from pydantic import BaseModel
7
+ from typing import List, Optional
8
+ from huggingface_hub import InferenceClient
9
+ from scipy.spatial.distance import cosine
10
+
11
+ app = FastAPI(title="Course Recommendation API")
12
+
13
+ # Initialize Hugging Face Inference Client
14
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN", os.getenv['HF_API_TOKEN'])
15
+ client = InferenceClient(model="sentence-transformers/all-MiniLM-L6-v2", token=HF_API_TOKEN)
16
+
17
+ # Initialize ChromaDB
18
+ chroma_client = chromadb.PersistentClient(path="./chroma_db")
19
+ collection = chroma_client.get_or_create_collection(name="courses")
20
+
21
+ def get_embedding(text):
22
+ response = client.post(json={"inputs": text}, task="feature-extraction")
23
+
24
+ # Handle different response formats
25
+ if hasattr(response, 'tolist'):
26
+ return response.tolist() # Handle if it's already a NumPy array
27
+ elif isinstance(response, list):
28
+ if len(response) > 0 and isinstance(response[0], list):
29
+ return response[0] # Return first item if response is a list of lists
30
+ else:
31
+ return response # Return as is if it's a flat list
32
+ else:
33
+ # Convert from bytes if needed
34
+ try:
35
+ if isinstance(response, bytes):
36
+ import ast
37
+ return ast.literal_eval(response.decode('utf-8'))
38
+ else:
39
+ return response
40
+ except:
41
+ raise ValueError(f"Unexpected embedding format: {response}")
42
+
43
+ class Course(BaseModel):
44
+ course_id: str
45
+ course_name: str
46
+ abstract: str
47
+
48
+ class CourseResponse(BaseModel):
49
+ course_id: str
50
+ name: str
51
+ similarity: float
52
+
53
+ @app.post("/add_course")
54
+ async def add_course(course: Course):
55
+ """Add a single course to the database"""
56
+ text = f"Course: {course.course_name}, Description: {course.abstract}"
57
+
58
+ try:
59
+ embedding = get_embedding(text)
60
+ if not isinstance(embedding, list):
61
+ if hasattr(embedding, 'tolist'):
62
+ embedding = embedding.tolist()
63
+ else:
64
+ embedding = list(embedding)
65
+
66
+ collection.add(
67
+ ids=[course.course_id],
68
+ embeddings=[embedding],
69
+ metadatas=[{"course_id": course.course_id, "name": course.course_name}]
70
+ )
71
+ return {"status": "success", "message": "Course added successfully"}
72
+ except Exception as e:
73
+ raise HTTPException(status_code=500, detail=f"Error adding course: {str(e)}")
74
+
75
+ @app.post("/upload_courses")
76
+ async def upload_courses(file: UploadFile = File(...)):
77
+ """Upload a JSON file with multiple courses"""
78
+ try:
79
+ contents = await file.read()
80
+ courses = json.loads(contents)
81
+
82
+ for course in courses:
83
+ text = f"Course: {course['course_name']}, Description: {course['abstract']}"
84
+ embedding = get_embedding(text)
85
+
86
+ if not isinstance(embedding, list):
87
+ if hasattr(embedding, 'tolist'):
88
+ embedding = embedding.tolist()
89
+ else:
90
+ embedding = list(embedding)
91
+
92
+ collection.add(
93
+ ids=[str(course["course_id"])],
94
+ embeddings=[embedding],
95
+ metadatas=[{"course_id": course["course_id"], "name": course["course_name"]}]
96
+ )
97
+
98
+ return {"status": "success", "message": f"{len(courses)} courses added successfully"}
99
+ except Exception as e:
100
+ raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
101
+
102
+ @app.get("/search", response_model=List[CourseResponse])
103
+ async def search_courses(query: str, limit: Optional[int] = 3):
104
+ """Find similar courses based on query text"""
105
+ try:
106
+ query_embedding = get_embedding(query)
107
+
108
+ # Ensure query embedding is properly formatted
109
+ if not isinstance(query_embedding, (list, np.ndarray)):
110
+ if hasattr(query_embedding, 'tolist'):
111
+ query_embedding = query_embedding.tolist()
112
+ else:
113
+ query_embedding = list(query_embedding)
114
+
115
+ # Retrieve stored embeddings
116
+ results = collection.get(include=["embeddings", "metadatas"])
117
+ courses = results["metadatas"]
118
+ stored_embeddings = results["embeddings"]
119
+
120
+ if not courses:
121
+ return []
122
+
123
+ # Compute cosine similarities
124
+ similarities = [1 - cosine(query_embedding, emb) for emb in stored_embeddings]
125
+
126
+ # Get top similar courses
127
+ top_indices = np.argsort(similarities)[-limit:][::-1]
128
+
129
+ # Format response
130
+ response = []
131
+ for i in top_indices:
132
+ response.append(
133
+ CourseResponse(
134
+ course_id=courses[i]["course_id"],
135
+ name=courses[i]["name"],
136
+ similarity=float(similarities[i])
137
+ )
138
+ )
139
+
140
+ return response
141
+ except Exception as e:
142
+ raise HTTPException(status_code=500, detail=f"Error searching courses: {str(e)}")
143
+
144
+ @app.get("/health")
145
+ async def health_check():
146
+ """Health check endpoint"""
147
+ return {"status": "ok"}