Spaces:
Paused
Paused
Commit
·
47ab7c3
1
Parent(s):
cde87a2
Updates
Browse files
- Dockerfile +4 -0
- app.py +10 -4
- requirements.txt +3 -2
Dockerfile
CHANGED
|
@@ -5,6 +5,10 @@ RUN apt-get update && apt-get install -y \
|
|
| 5 |
ffmpeg git build-essential python3-dev && \
|
| 6 |
rm -rf /var/lib/apt/lists/*
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
# Set working directory
|
| 9 |
WORKDIR /app/chatbot
|
| 10 |
|
|
|
|
| 5 |
ffmpeg git build-essential python3-dev && \
|
| 6 |
rm -rf /var/lib/apt/lists/*
|
| 7 |
|
| 8 |
+
# Install CPU-specific PyTorch (wheel from PyTorch index)
|
| 9 |
+
RUN pip install torch==2.1.1+cpu torchvision==0.16.1+cpu torchaudio==2.1.1+cpu \
|
| 10 |
+
-f https://download.pytorch.org/whl/torch_stable.html
|
| 11 |
+
|
| 12 |
# Set working directory
|
| 13 |
WORKDIR /app/chatbot
|
| 14 |
|
app.py
CHANGED
|
@@ -5,11 +5,12 @@ import torch
|
|
| 5 |
import soundfile as sf
|
| 6 |
from flask import Flask, request, jsonify, send_from_directory
|
| 7 |
from flask_cors import CORS
|
| 8 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
from deep_translator import GoogleTranslator
|
| 10 |
from textblob import TextBlob
|
| 11 |
import nltk
|
| 12 |
from parler_tts import ParlerTTSForConditionalGeneration
|
|
|
|
| 13 |
|
| 14 |
# Flask setup
|
| 15 |
dir_path = os.path.dirname(os.path.realpath(__file__))
|
|
@@ -32,16 +33,19 @@ except LookupError:
|
|
| 32 |
nltk.download('punkt')
|
| 33 |
nltk.download('punkt_tab')
|
| 34 |
|
|
|
|
| 35 |
|
| 36 |
class ChatBot:
|
| 37 |
def __init__(self):
|
| 38 |
self.chat_history_ids = None
|
| 39 |
self.bot_input_ids = None
|
| 40 |
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 41 |
-
self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
| 42 |
-
self.model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
|
|
|
| 43 |
# Parler-TTS Setup
|
| 44 |
-
self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay").to(
|
|
|
|
| 45 |
self.tts_tokenizer = AutoTokenizer.from_pretrained("doublesizebed/parler-tts-mini-malay")
|
| 46 |
self.description_tokenizer = AutoTokenizer.from_pretrained(self.tts_model.config.text_encoder._name_or_path)
|
| 47 |
|
|
@@ -71,6 +75,7 @@ class ChatBot:
|
|
| 71 |
else:
|
| 72 |
self.chat_history_ids = torch.cat([self.chat_history_ids, prompt_ids], dim=-1)
|
| 73 |
|
|
|
|
| 74 |
output = self.model.generate(
|
| 75 |
self.chat_history_ids,
|
| 76 |
max_length=self.chat_history_ids.shape[-1] + 128,
|
|
@@ -127,6 +132,7 @@ class ChatBot:
|
|
| 127 |
desc_inputs = self.description_tokenizer(description, return_tensors="pt", padding=True).to(self.device)
|
| 128 |
text_inputs = self.tts_tokenizer(text, return_tensors="pt", padding=True).to(self.device)
|
| 129 |
|
|
|
|
| 130 |
generation = self.tts_model.generate(
|
| 131 |
input_ids=desc_inputs.input_ids,
|
| 132 |
attention_mask=desc_inputs.attention_mask,
|
|
|
|
| 5 |
import soundfile as sf
|
| 6 |
from flask import Flask, request, jsonify, send_from_directory
|
| 7 |
from flask_cors import CORS
|
| 8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 9 |
from deep_translator import GoogleTranslator
|
| 10 |
from textblob import TextBlob
|
| 11 |
import nltk
|
| 12 |
from parler_tts import ParlerTTSForConditionalGeneration
|
| 13 |
+
from torch.quantization import quantize_dynamic
|
| 14 |
|
| 15 |
# Flask setup
|
| 16 |
dir_path = os.path.dirname(os.path.realpath(__file__))
|
|
|
|
| 33 |
nltk.download('punkt')
|
| 34 |
nltk.download('punkt_tab')
|
| 35 |
|
| 36 |
+
bnb_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
|
| 37 |
|
| 38 |
class ChatBot:
|
| 39 |
def __init__(self):
|
| 40 |
self.chat_history_ids = None
|
| 41 |
self.bot_input_ids = None
|
| 42 |
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 43 |
+
self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_fast=False)
|
| 44 |
+
self.model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", quantization_config=bnb_config, device_map="cpu")
|
| 45 |
+
self.model = quantize_dynamic(self.model, {torch.nn.Linear}, dtype=torch.qint8)
|
| 46 |
# Parler-TTS Setup
|
| 47 |
+
self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay").to("cpu")
|
| 48 |
+
self.tts_model = quantize_dynamic(self.tts_model, {torch.nn.Linear}, dtype=torch.qint8)
|
| 49 |
self.tts_tokenizer = AutoTokenizer.from_pretrained("doublesizebed/parler-tts-mini-malay")
|
| 50 |
self.description_tokenizer = AutoTokenizer.from_pretrained(self.tts_model.config.text_encoder._name_or_path)
|
| 51 |
|
|
|
|
| 75 |
else:
|
| 76 |
self.chat_history_ids = torch.cat([self.chat_history_ids, prompt_ids], dim=-1)
|
| 77 |
|
| 78 |
+
self.model.eval()
|
| 79 |
output = self.model.generate(
|
| 80 |
self.chat_history_ids,
|
| 81 |
max_length=self.chat_history_ids.shape[-1] + 128,
|
|
|
|
| 132 |
desc_inputs = self.description_tokenizer(description, return_tensors="pt", padding=True).to(self.device)
|
| 133 |
text_inputs = self.tts_tokenizer(text, return_tensors="pt", padding=True).to(self.device)
|
| 134 |
|
| 135 |
+
self.tts_model.eval()
|
| 136 |
generation = self.tts_model.generate(
|
| 137 |
input_ids=desc_inputs.input_ids,
|
| 138 |
attention_mask=desc_inputs.attention_mask,
|
requirements.txt
CHANGED
|
@@ -2,7 +2,7 @@ flask
|
|
| 2 |
Cython
|
| 3 |
flask[async]
|
| 4 |
waitress
|
| 5 |
-
torch
|
| 6 |
transformers
|
| 7 |
deep-translator
|
| 8 |
nest_asyncio
|
|
@@ -11,4 +11,5 @@ soundfile
|
|
| 11 |
textblob
|
| 12 |
malaya
|
| 13 |
parler_tts
|
| 14 |
-
nltk
|
|
|
|
|
|
| 2 |
Cython
|
| 3 |
flask[async]
|
| 4 |
waitress
|
| 5 |
+
torch==2.1.1+cpu
|
| 6 |
transformers
|
| 7 |
deep-translator
|
| 8 |
nest_asyncio
|
|
|
|
| 11 |
textblob
|
| 12 |
malaya
|
| 13 |
parler_tts
|
| 14 |
+
nltk
|
| 15 |
+
bitsandbytes
|