safiaa02 committed on
Commit
63b8d64
·
verified ·
1 Parent(s): 7980c0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -3,8 +3,10 @@ import json
3
  import streamlit as st
4
  import faiss
5
  import numpy as np
 
6
  from sentence_transformers import SentenceTransformer
7
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
8
  from reportlab.lib.pagesizes import A4
9
  from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
10
  from reportlab.lib.styles import getSampleStyleSheet
@@ -54,19 +56,32 @@ def retrieve_milestone(user_input):
54
  return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
55
 
56
  # Initialize IBM Granite Model
57
- model_name = "ibm-granite/granite-rag-3.0-8b-lora"
58
- tokenizer = AutoTokenizer.from_pretrained(model_name)
59
- lm_model = AutoModelForCausalLM.from_pretrained(model_name)
60
- generation_pipeline = pipeline("text-generation", model=lm_model, tokenizer=tokenizer, max_length=512)
 
 
 
 
61
 
62
  def generate_response(user_input, child_age):
63
  relevant_milestone = retrieve_milestone(user_input)
64
- prompt = (f"The child is {child_age} months old. Based on the given traits: {user_input}, "
65
- f"determine whether the child is meeting expected milestones. "
66
- f"Relevant milestone: {relevant_milestone}. "
67
- "If there are any concerns, suggest steps the parents can take. ")
68
- response = generation_pipeline(prompt)
69
- return response[0]['generated_text']
 
 
 
 
 
 
 
 
 
70
 
71
  # Streamlit UI Styling
72
  st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
 
3
  import streamlit as st
4
  import faiss
5
  import numpy as np
6
+ import torch
7
  from sentence_transformers import SentenceTransformer
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from peft import PeftModel
10
  from reportlab.lib.pagesizes import A4
11
  from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
12
  from reportlab.lib.styles import getSampleStyleSheet
 
56
  return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
57
 
58
  # Initialize IBM Granite Model
59
+ BASE_NAME = "ibm-granite/granite-3.0-8b-instruct"
60
+ LORA_NAME = "ibm-granite/granite-rag-3.0-8b-lora"
61
+
62
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
63
+
64
+ tokenizer = AutoTokenizer.from_pretrained(BASE_NAME, padding_side='left', trust_remote_code=True)
65
+ model_base = AutoModelForCausalLM.from_pretrained(BASE_NAME, device_map="auto")
66
+ model_rag = PeftModel.from_pretrained(model_base, LORA_NAME)
67
 
68
  def generate_response(user_input, child_age):
69
  relevant_milestone = retrieve_milestone(user_input)
70
+ question_chat = [
71
+ {
72
+ "role": "system",
73
+ "content": "{\"instruction\": \"Respond to the user's latest question based solely on the information provided in the documents. Ensure that your response is strictly aligned with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data. Make sure that your response follows the attributes mentioned in the 'meta' field.\", \"documents\": [{\"doc_id\": 1, \"text\": \"The child is {child_age} months old. Based on the given traits: {user_input}, determine whether the child is meeting expected milestones. Relevant milestone: {relevant_milestone}. If there are any concerns, suggest steps the parents can take.\"}], \"meta\": {\"hallucination_tags\": true, \"citations\": true}}"
74
+ },
75
+ {
76
+ "role": "user",
77
+ "content": user_input
78
+ }
79
+ ]
80
+ input_text = tokenizer.apply_chat_template(question_chat, tokenize=False, add_generation_prompt=True)
81
+ inputs = tokenizer(input_text, return_tensors="pt")
82
+ output = model_rag.generate(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device), max_new_tokens=500)
83
+ output_text = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
84
+ return output_text
85
 
86
  # Streamlit UI Styling
87
  st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")