Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -8,6 +8,13 @@ from PyPDF2 import PdfReader
|
|
| 8 |
from fastapi import Depends
|
| 9 |
#在FastAPI中,Depends()函数用于声明依赖项
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
import random
|
| 12 |
import string
|
| 13 |
import sys
|
|
@@ -19,10 +26,44 @@ import os
|
|
| 19 |
from dotenv import load_dotenv
|
| 20 |
load_dotenv()
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
def generate_random_string(length):
|
| 23 |
letters = string.ascii_lowercase
|
| 24 |
return ''.join(random.choice(letters) for i in range(length))
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
app = FastAPI()
|
| 27 |
|
| 28 |
class FileToProcess(BaseModel):
|
|
@@ -66,10 +107,48 @@ async def pdf_file_qa_process(username: str, request: Request, file_to_process:
|
|
| 66 |
text = page.extract_text()
|
| 67 |
if text:
|
| 68 |
raw_text += text
|
| 69 |
-
temp_texts = text_splitter.split_text(raw_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
api_call_msg={"INFO": f"File '{file_saved_in_api}' saved to your profile."}
|
| 75 |
print(api_call_msg)
|
|
|
|
| 8 |
from fastapi import Depends
|
| 9 |
#在FastAPI中,Depends()函数用于声明依赖项
|
| 10 |
|
| 11 |
+
from langchain.chains.question_answering import load_qa_chain
|
| 12 |
+
from langchain import PromptTemplate, LLMChain
|
| 13 |
+
from langchain import HuggingFaceHub
|
| 14 |
+
from langchain.document_loaders import TextLoader
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
import requests
|
| 18 |
import random
|
| 19 |
import string
|
| 20 |
import sys
|
|
|
|
| 26 |
from dotenv import load_dotenv
|
| 27 |
load_dotenv()
|
| 28 |
|
| 29 |
+
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
|
| 30 |
+
model_id = os.getenv('model_id')
|
| 31 |
+
hf_token = os.getenv('hf_token')
|
| 32 |
+
repo_id = os.getenv('repo_id')
|
| 33 |
+
|
| 34 |
+
def get_embeddings(input_str_texts):
|
| 35 |
+
response = requests.post(api_url, headers=headers, json={"inputs": input_str_texts, "options":{"wait_for_model":True}})
|
| 36 |
+
return response.json()
|
| 37 |
+
|
| 38 |
def generate_random_string(length):
|
| 39 |
letters = string.ascii_lowercase
|
| 40 |
return ''.join(random.choice(letters) for i in range(length))
|
| 41 |
|
| 42 |
+
def remove_context(text):
|
| 43 |
+
if 'Context:' in text:
|
| 44 |
+
end_of_context = text.find('\n\n')
|
| 45 |
+
return text[end_of_context + 2:]
|
| 46 |
+
else:
|
| 47 |
+
return text
|
| 48 |
+
|
| 49 |
+
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
|
| 50 |
+
headers = {"Authorization": f"Bearer {hf_token}"}
|
| 51 |
+
|
| 52 |
+
llm = HuggingFaceHub(repo_id=repo_id,
|
| 53 |
+
model_kwargs={"min_length":100,
|
| 54 |
+
"max_new_tokens":1024, "do_sample":True,
|
| 55 |
+
"temperature":0.1,
|
| 56 |
+
"top_k":50,
|
| 57 |
+
"top_p":0.95, "eos_token_id":49155})
|
| 58 |
+
|
| 59 |
+
prompt_template = """
|
| 60 |
+
You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question {question}. If you don't know the answer, just say that you don't know. DON'T try to make up an answer.
|
| 61 |
+
Your response should be full and easy to understand.
|
| 62 |
+
"""
|
| 63 |
+
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
|
| 64 |
+
|
| 65 |
+
chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)
|
| 66 |
+
|
| 67 |
app = FastAPI()
|
| 68 |
|
| 69 |
class FileToProcess(BaseModel):
|
|
|
|
| 107 |
text = page.extract_text()
|
| 108 |
if text:
|
| 109 |
raw_text += text
|
| 110 |
+
temp_texts = text_splitter.split_text(raw_text)
|
| 111 |
+
texts=temp_texts
|
| 112 |
+
initial_embeddings=get_embeddings(temp_texts)
|
| 113 |
+
db_embeddings = torch.FloatTensor(initial_embeddings)
|
| 114 |
+
print("db_embeddings created...")
|
| 115 |
+
|
| 116 |
+
#question = var_query.query
|
| 117 |
+
question = username
|
| 118 |
+
print("API Call Query Received: "+question)
|
| 119 |
+
q_embedding=get_embeddings(question)
|
| 120 |
+
final_q_embedding = torch.FloatTensor(q_embedding)
|
| 121 |
+
from sentence_transformers.util import semantic_search
|
| 122 |
+
hits = semantic_search(final_q_embedding, torch.FloatTensor(db_embeddings), top_k=5)
|
| 123 |
+
|
| 124 |
+
page_contents = []
|
| 125 |
+
for i in range(len(hits[0])):
|
| 126 |
+
page_content = texts[hits[0][i]['corpus_id']]
|
| 127 |
+
page_contents.append(page_content)
|
| 128 |
|
| 129 |
+
temp_page_contents=str(page_contents)
|
| 130 |
+
final_page_contents = temp_page_contents.replace('\\n', '')
|
| 131 |
+
random_string_2=generate_random_string(20)
|
| 132 |
+
file_path = random_string_2 + ".txt"
|
| 133 |
+
with open(file_path, "w", encoding="utf-8") as file:
|
| 134 |
+
file.write(final_page_contents)
|
| 135 |
+
|
| 136 |
+
loader = TextLoader(file_path, encoding="utf-8")
|
| 137 |
+
loaded_documents = loader.load()
|
| 138 |
+
|
| 139 |
+
temp_ai_response = chain({"input_documents": loaded_documents, "question": question}, return_only_outputs=False)
|
| 140 |
+
|
| 141 |
+
initial_ai_response=temp_ai_response['output_text']
|
| 142 |
+
|
| 143 |
+
cleaned_initial_ai_response = remove_context(initial_ai_response)
|
| 144 |
+
|
| 145 |
+
#final_ai_response = cleaned_initial_ai_response.partition('¿Cuál es')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '')
|
| 146 |
+
final_ai_response = cleaned_initial_ai_response.partition('¿Cuál es')[0].strip()
|
| 147 |
+
final_ai_response = final_ai_response.partition('¿Cuáles')[0].strip()
|
| 148 |
+
final_ai_response = final_ai_response.partition('<|end|>')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '')
|
| 149 |
+
new_final_ai_response = final_ai_response.split('Unhelpful Answer:')[0].strip()
|
| 150 |
+
new_final_ai_response = new_final_ai_response.split('Note:')[0].strip()
|
| 151 |
+
new_final_ai_response = new_final_ai_response.split('Please provide feedback on how to improve the chatbot.')[0].strip()
|
| 152 |
|
| 153 |
api_call_msg={"INFO": f"File '{file_saved_in_api}' saved to your profile."}
|
| 154 |
print(api_call_msg)
|