Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -235,7 +235,7 @@ import time
|
|
| 235 |
from fastapi import FastAPI, Request
|
| 236 |
from fastapi.responses import HTMLResponse
|
| 237 |
from fastapi.staticfiles import StaticFiles
|
| 238 |
-
from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader,
|
| 239 |
from llama_index.llms.huggingface import HuggingFaceLLM
|
| 240 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 241 |
from pydantic import BaseModel
|
|
@@ -244,7 +244,7 @@ import datetime
|
|
| 244 |
from fastapi.middleware.cors import CORSMiddleware
|
| 245 |
from fastapi.templating import Jinja2Templates
|
| 246 |
from simple_salesforce import Salesforce, SalesforceLogin
|
| 247 |
-
from transformers import AutoModelForSeq2SeqLM
|
| 248 |
|
| 249 |
# Define Pydantic model for incoming request body
|
| 250 |
class MessageRequest(BaseModel):
|
|
@@ -288,6 +288,7 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
| 288 |
templates = Jinja2Templates(directory="static")
|
| 289 |
|
| 290 |
# Configure Llama index settings
|
|
|
|
| 291 |
Settings.llm = HuggingFaceLLM(
|
| 292 |
model_name="google/flan-t5-small",
|
| 293 |
tokenizer_name="google/flan-t5-small",
|
|
@@ -295,6 +296,7 @@ Settings.llm = HuggingFaceLLM(
|
|
| 295 |
max_new_tokens=256,
|
| 296 |
generate_kwargs={"temperature": 0.1, "do_sample": True},
|
| 297 |
model=AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small"),
|
|
|
|
| 298 |
device_map="auto" # Automatically use GPU if available, else CPU
|
| 299 |
)
|
| 300 |
Settings.embed_model = HuggingFaceEmbedding(
|
|
@@ -341,18 +343,15 @@ def split_name(full_name):
|
|
| 341 |
initialize() # Run initialization tasks
|
| 342 |
|
| 343 |
def handle_query(query):
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
)
|
| 354 |
-
]
|
| 355 |
-
text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
|
| 356 |
|
| 357 |
storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
|
| 358 |
index = load_index_from_storage(storage_context)
|
|
@@ -361,7 +360,7 @@ def handle_query(query):
|
|
| 361 |
if past_query.strip():
|
| 362 |
context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
|
| 363 |
|
| 364 |
-
query_engine = index.as_query_engine(text_qa_template=text_qa_template
|
| 365 |
answer = query_engine.query(query)
|
| 366 |
|
| 367 |
if hasattr(answer, "response"):
|
|
|
|
| 235 |
from fastapi import FastAPI, Request
|
| 236 |
from fastapi.responses import HTMLResponse
|
| 237 |
from fastapi.staticfiles import StaticFiles
|
| 238 |
+
from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, Settings
|
| 239 |
from llama_index.llms.huggingface import HuggingFaceLLM
|
| 240 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 241 |
from pydantic import BaseModel
|
|
|
|
| 244 |
from fastapi.middleware.cors import CORSMiddleware
|
| 245 |
from fastapi.templating import Jinja2Templates
|
| 246 |
from simple_salesforce import Salesforce, SalesforceLogin
|
| 247 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 248 |
|
| 249 |
# Define Pydantic model for incoming request body
|
| 250 |
class MessageRequest(BaseModel):
|
|
|
|
| 288 |
templates = Jinja2Templates(directory="static")
|
| 289 |
|
| 290 |
# Configure Llama index settings
|
| 291 |
+
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
|
| 292 |
Settings.llm = HuggingFaceLLM(
|
| 293 |
model_name="google/flan-t5-small",
|
| 294 |
tokenizer_name="google/flan-t5-small",
|
|
|
|
| 296 |
max_new_tokens=256,
|
| 297 |
generate_kwargs={"temperature": 0.1, "do_sample": True},
|
| 298 |
model=AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small"),
|
| 299 |
+
tokenizer=tokenizer,
|
| 300 |
device_map="auto" # Automatically use GPU if available, else CPU
|
| 301 |
)
|
| 302 |
Settings.embed_model = HuggingFaceEmbedding(
|
|
|
|
| 343 |
initialize() # Run initialization tasks
|
| 344 |
|
| 345 |
def handle_query(query):
|
| 346 |
+
# Custom prompt template for flan-t5-small (no chat template)
|
| 347 |
+
text_qa_template = PromptTemplate(
|
| 348 |
+
"""
|
| 349 |
+
You are Clara, a Redfernstech chatbot. Provide accurate, concise answers (10-15 words) based on company data.
|
| 350 |
+
Context: {context_str}
|
| 351 |
+
Question: {query_str}
|
| 352 |
+
Answer:
|
| 353 |
+
"""
|
| 354 |
+
)
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
|
| 357 |
index = load_index_from_storage(storage_context)
|
|
|
|
| 360 |
if past_query.strip():
|
| 361 |
context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
|
| 362 |
|
| 363 |
+
query_engine = index.as_query_engine(text_qa_template=text_qa_template)
|
| 364 |
answer = query_engine.query(query)
|
| 365 |
|
| 366 |
if hasattr(answer, "response"):
|