Hugging Face Space commit view — "Update app.py" (fixes the Space's build error, which was caused by an incomplete `from transformers import` statement).
File changed: app.py
|
@@ -3,8 +3,10 @@ import json
|
|
| 3 |
import streamlit as st
|
| 4 |
import faiss
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
-
from transformers import
|
|
|
|
| 8 |
from reportlab.lib.pagesizes import A4
|
| 9 |
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
|
| 10 |
from reportlab.lib.styles import getSampleStyleSheet
|
|
@@ -54,19 +56,32 @@ def retrieve_milestone(user_input):
|
|
| 54 |
return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
|
| 55 |
|
| 56 |
# Initialize IBM Granite Model
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def generate_response(user_input, child_age):
|
| 63 |
relevant_milestone = retrieve_milestone(user_input)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
# Streamlit UI Styling
|
| 72 |
st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
|
|
|
|
| 3 |
import streamlit as st
|
| 4 |
import faiss
|
| 5 |
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 9 |
+
from peft import PeftModel
|
| 10 |
from reportlab.lib.pagesizes import A4
|
| 11 |
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
|
| 12 |
from reportlab.lib.styles import getSampleStyleSheet
|
|
|
|
| 56 |
return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
|
| 57 |
|
| 58 |
# Initialize IBM Granite Model
BASE_NAME = "ibm-granite/granite-3.0-8b-instruct"
LORA_NAME = "ibm-granite/granite-rag-3.0-8b-lora"

# NOTE(review): device_map="auto" already places the model; `device` is kept
# for moving input tensors in generate_response.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource(show_spinner="Loading IBM Granite model...")
def _load_granite():
    """Load the tokenizer, base model, and RAG LoRA adapter exactly once.

    Streamlit re-executes the entire script on every user interaction;
    without caching, this 8B-parameter model would be re-loaded on each
    rerun, exhausting memory and stalling the app. ``st.cache_resource``
    keeps a single shared instance alive across reruns and sessions.

    Returns:
        tuple: (tokenizer, base model, PEFT-wrapped RAG model).
    """
    tok = AutoTokenizer.from_pretrained(BASE_NAME, padding_side='left', trust_remote_code=True)
    base = AutoModelForCausalLM.from_pretrained(BASE_NAME, device_map="auto")
    rag = PeftModel.from_pretrained(base, LORA_NAME)
    return tok, base, rag


# Preserve the original module-level names so downstream code is unchanged.
tokenizer, model_base, model_rag = _load_granite()
| 67 |
|
| 68 |
def generate_response(user_input, child_age):
    """Generate a milestone assessment for a child via the Granite RAG model.

    Args:
        user_input: Free-text description of the child's observed traits.
        child_age: The child's age in months.

    Returns:
        str: The model's answer, with the prompt tokens stripped.
    """
    relevant_milestone = retrieve_milestone(user_input)

    # BUG FIX: the original embedded {child_age}/{user_input}/{relevant_milestone}
    # inside a plain (non-f) string literal, so the literal placeholder text was
    # sent to the model and the retrieved milestone never reached the prompt.
    # Interpolate for real, and build the Granite RAG system payload with
    # json.dumps instead of a fragile hand-escaped JSON string.
    document_text = (
        f"The child is {child_age} months old. Based on the given traits: {user_input}, "
        f"determine whether the child is meeting expected milestones. "
        f"Relevant milestone: {relevant_milestone}. "
        "If there are any concerns, suggest steps the parents can take."
    )
    system_payload = {
        "instruction": (
            "Respond to the user's latest question based solely on the information "
            "provided in the documents. Ensure that your response is strictly aligned "
            "with the facts in the provided documents. If the information needed to "
            "answer the question is not available in the documents, inform the user "
            "that the question cannot be answered based on the available data. Make "
            "sure that your response follows the attributes mentioned in the 'meta' field."
        ),
        "documents": [{"doc_id": 1, "text": document_text}],
        "meta": {"hallucination_tags": True, "citations": True},
    }
    question_chat = [
        {"role": "system", "content": json.dumps(system_payload)},
        {"role": "user", "content": user_input},
    ]

    input_text = tokenizer.apply_chat_template(question_chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_text, return_tensors="pt")
    # Inference only: no_grad avoids building autograd state for an 8B model.
    with torch.no_grad():
        output = model_rag.generate(
            inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            max_new_tokens=500,
        )
    # Decode only the newly generated tokens (slice off the prompt length).
    output_text = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return output_text
|
| 85 |
|
| 86 |
# Streamlit UI styling: must be the first Streamlit call after model setup;
# configures the browser tab title/icon and a full-width layout.
st.set_page_config(
    page_title="Tiny Triumphs Tracker",
    page_icon="👶",
    layout="wide",
)