william4416 commited on
Commit
8ffb2d8
·
verified ·
1 Parent(s): 9c2948a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -44
app.py CHANGED
@@ -7,65 +7,64 @@ import json
7
  # Create FastAPI app instance
8
  app = FastAPI()
9
 
10
- # Load DialoGPT model and tokenizer
11
- try:
12
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
13
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
14
- except Exception as e:
15
- raise HTTPException(status_code=500, detail=f"Model loading failed: {e}")
 
 
16
 
17
- # Load courses data
18
- try:
19
- with open("uts_courses.json", "r") as file:
20
- courses_data = json.load(file)
21
- except Exception as e:
22
- raise HTTPException(status_code=500, detail=f"Courses data loading failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Define user input model
25
  class UserInput(BaseModel):
26
  user_input: str
27
 
28
- # Generate response function
29
- def generate_response(user_input: str):
30
- """
31
- Generate response based on user input
32
-
33
- Args:
34
- user_input: User input text
35
-
36
- Returns:
37
- Generated response text
38
- """
39
- if user_input.lower() == "help":
40
- return "I can help you with UTS courses information, feel free to ask!"
41
- elif user_input.lower() == "exit":
42
- return "Goodbye!"
43
- elif user_input.lower() == "list courses":
44
- # Generate course list
45
- course_list = "\n".join([f"{category}: {', '.join(courses)}" for category, courses in courses_data["courses"].items()])
46
- return f"Here are the available courses:\n{course_list}"
47
- elif user_input.lower() in courses_data["courses"]:
48
- # List courses under the specified course category
49
- return f"The courses in {user_input} category are: {', '.join(courses_data['courses'][user_input])}"
50
- else:
51
- # Use DialoGPT model to generate response
52
- input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
53
- response_ids = model.generate(input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
54
- response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
55
- return response
56
 
57
  # Define API route
58
  @app.post("/")
59
  async def chat(user_input: UserInput):
60
  """
61
  Process user input and return response
62
-
63
  Args:
64
  user_input: User input JSON data
65
-
66
  Returns:
67
  JSON data containing the response text
68
  """
69
- response = generate_response(user_input.user_input)
70
  return {"response": response}
71
-
 
7
  # Create FastAPI app instance
8
  app = FastAPI()
9
 
10
+ class ChatBot:
11
+ def __init__(self):
12
+ # Load DialoGPT model and tokenizer
13
+ try:
14
+ self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
15
+ self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
16
+ except Exception as e:
17
+ raise HTTPException(status_code=500, detail=f"Model loading failed: {e}")
18
 
19
+ # Load courses data
20
+ try:
21
+ with open("uts_courses.json", "r") as file:
22
+ self.courses_data = json.load(file)
23
+ except Exception as e:
24
+ raise HTTPException(status_code=500, detail=f"Courses data loading failed: {e}")
25
+
26
+ def generate_response(self, user_input: str):
27
+ """
28
+ Generate response based on user input
29
+ Args:
30
+ user_input: User input text
31
+ Returns:
32
+ Generated response text
33
+ """
34
+ if user_input.lower() == "help":
35
+ return "I can help you with UTS courses information, feel free to ask!"
36
+ elif user_input.lower() == "exit":
37
+ return "Goodbye!"
38
+ elif user_input.lower() == "list courses":
39
+ # Generate course list
40
+ course_list = "\n".join([f"{category}: {', '.join(courses)}" for category, courses in self.courses_data["courses"].items()])
41
+ return f"Here are the available courses:\n{course_list}"
42
+ elif user_input.lower() in self.courses_data["courses"]:
43
+ # List courses under the specified course category
44
+ return f"The courses in {user_input} category are: {', '.join(self.courses_data['courses'][user_input])}"
45
+ else:
46
+ # Use DialoGPT model to generate response
47
+ input_ids = self.tokenizer.encode(user_input + self.tokenizer.eos_token, return_tensors="pt")
48
+ response_ids = self.model.generate(input_ids, max_length=100, pad_token_id=self.tokenizer.eos_token_id)
49
+ response = self.tokenizer.decode(response_ids[0], skip_special_tokens=True)
50
+ return response
51
 
52
  # Define user input model
53
  class UserInput(BaseModel):
54
  user_input: str
55
 
56
+ # Create chatbot instance
57
+ chatbot = ChatBot()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Define API route
60
  @app.post("/")
61
  async def chat(user_input: UserInput):
62
  """
63
  Process user input and return response
 
64
  Args:
65
  user_input: User input JSON data
 
66
  Returns:
67
  JSON data containing the response text
68
  """
69
+ response = chatbot.generate_response(user_input.user_input)
70
  return {"response": response}