jpatel20 committed on
Commit
9e8cb4f
·
verified ·
1 Parent(s): e227a15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -72
app.py CHANGED
@@ -1,81 +1,32 @@
1
- import os
2
- from flask import Flask
3
- from flask_cors import CORS
4
- from flask_limiter import Limiter
5
- from flask_limiter.util import get_remote_address
6
- from dotenv import load_dotenv
7
- import logging
8
- from utils.db import init_db
9
- from config.config import Config
10
- from flask import render_template
11
  import gradio as gr
12
- from transformers import AutoTokenizer, AutoModelForCausalLM
13
-
14
- # Load environment variables
15
- load_dotenv()
16
-
17
- # Logging setup
18
- logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
19
-
20
- app = Flask(__name__)
21
- app.config.from_object(Config)
22
- CORS(app)
23
-
24
- # Rate limiter
25
- limiter = Limiter(
26
- get_remote_address,
27
- app=app,
28
- default_limits=["5 per minute"]
29
- )
30
-
31
- # Initialize DB
32
- init_db()
33
-
34
- # Register blueprints
35
- from routes.chat import chat_bp
36
- from routes.feedback import feedback_bp
37
- app.register_blueprint(chat_bp)
38
- app.register_blueprint(feedback_bp)
39
-
40
- @app.route("/")
41
- def home():
42
- api_key = Config.HF_API_TOKEN # Get the API key from your config
43
- return render_template("index.html", api_key=api_key)
44
-
45
- @app.errorhandler(429)
46
- def ratelimit_handler(e):
47
- return {"error": "Rate limit exceeded. Please try again later."}, 429
48
-
49
- # Load model (will be cached on Hugging Face servers)
50
- tokenizer = AutoTokenizer.from_pretrained("bitext/Mistral-7B-Customer-Support")
51
- model = AutoModelForCausalLM.from_pretrained("bitext/Mistral-7B-Customer-Support")
52
 
53
  def chat(message, history):
54
- # Format conversation history
55
- messages = []
56
- for human, assistant in history:
57
- messages.append({"role": "user", "content": human})
58
- messages.append({"role": "assistant", "content": assistant})
59
- messages.append({"role": "user", "content": message})
60
 
61
- # Generate response
62
- inputs = tokenizer.apply_chat_template(
63
- messages,
64
- add_generation_prompt=True,
65
- tokenize=True,
66
- return_dict=True,
67
- return_tensors="pt",
68
- )
69
 
70
- outputs = model.generate(**inputs, max_new_tokens=150, temperature=0.7)
71
- response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
72
- return response.strip()
 
 
 
 
 
 
 
 
 
73
 
74
  # Create Gradio interface
75
  demo = gr.ChatInterface(
76
  fn=chat,
77
  title="AI Customer Service Chatbot",
78
- description="Powered by Mistral-7B Customer Support model",
79
  examples=[
80
  ["How can I reset my password?"],
81
  ["What are your return policies?"],
@@ -84,7 +35,4 @@ demo = gr.ChatInterface(
84
  ]
85
  )
86
 
87
- demo.launch()
88
-
89
- if __name__ == "__main__":
90
- app.run(debug=True)
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import requests
3
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
def chat(message, history):
    """Generate a customer-support reply via the hosted HF Inference API.

    Args:
        message: The latest user message (plain text).
        history: Gradio chat history (list of (user, assistant) pairs).
            NOTE(review): the hosted endpoint is stateless and history is
            currently not sent — consider folding it into the prompt.

    Returns:
        The model's reply string, or a human-readable error message when
        the API call fails (the UI displays whatever is returned).
    """
    API_URL = "https://api-inference.huggingface.co/models/bitext/Mistral-7B-Customer-Support"
    headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

    # return_full_text=False: by default the text-generation endpoint echoes
    # the prompt back at the start of "generated_text"; we only want the
    # completion. Generation parameters mirror the previous local-model
    # settings (max_new_tokens=150, temperature=0.7).
    payload = {
        "inputs": message,
        "parameters": {
            "max_new_tokens": 150,
            "temperature": 0.7,
            "return_full_text": False,
        },
    }

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
        if response.status_code == 200:
            # Successful responses are a list of dicts with a
            # "generated_text" key; guard against malformed payloads
            # instead of letting a KeyError/IndexError escape.
            result = response.json()
            if isinstance(result, list) and result and "generated_text" in result[0]:
                return result[0]["generated_text"].strip()
            return "Sorry, I couldn't generate a response."
        return f"API Error: {response.status_code}"
    except Exception as e:
        # Boundary handler: network errors, timeouts, and bad JSON all
        # surface as a chat message rather than crashing the UI.
        return f"Error: {str(e)}"
24
 
25
  # Create Gradio interface
26
  demo = gr.ChatInterface(
27
  fn=chat,
28
  title="AI Customer Service Chatbot",
29
+ description="Powered by Mistral-7B Customer Support (via Hugging Face API)",
30
  examples=[
31
  ["How can I reset my password?"],
32
  ["What are your return policies?"],
 
35
  ]
36
  )
37
 
38
+ demo.launch()