kdevoe commited on
Commit
684c258
·
verified ·
1 Parent(s): f0b621d

Using locally saved fine-tuned model

Browse files
Files changed (1) hide show
  1. app.py +15 -21
app.py CHANGED
@@ -6,23 +6,23 @@ from langchain.memory import ConversationBufferMemory
6
  # Move model to device (GPU if available)
7
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
8
 
9
- # Load the tokenizer and model for DistilGPT-2
10
  tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
11
- model = GPT2LMHeadModel.from_pretrained("distilgpt2")
12
- model.to(device)
13
 
14
- # # Load summarization model (e.g., T5-small)
15
- # summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
16
- # summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)
 
 
 
 
 
 
 
 
17
 
18
- # def summarize_history(history):
19
- # input_ids = summarizer_tokenizer.encode(
20
- # "summarize: " + history,
21
- # return_tensors="pt"
22
- # ).to(device)
23
- # summary_ids = summarizer_model.generate(input_ids, max_length=50, min_length=25, length_penalty=5., num_beams=2)
24
- # summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
25
- # return summary
26
 
27
  # Set up conversational memory using LangChain's ConversationBufferMemory
28
  memory = ConversationBufferMemory()
@@ -32,10 +32,6 @@ def chat_with_distilgpt2(input_text):
32
  # Retrieve conversation history
33
  conversation_history = memory.load_memory_variables({})['history']
34
 
35
- # # Summarize if history exceeds certain length
36
- # if len(conversation_history.split()) > 200:
37
- # conversation_history = summarize_history(conversation_history)
38
-
39
  # Combine the (possibly summarized) history with the current user input
40
  full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
41
 
@@ -50,9 +46,6 @@ def chat_with_distilgpt2(input_text):
50
  num_return_sequences=1,
51
  no_repeat_ngram_size=3,
52
  repetition_penalty=1.2,
53
- # temperature=0.9,
54
- # top_k=20,
55
- # top_p=0.8,
56
  early_stopping=True,
57
  pad_token_id=tokenizer.eos_token_id,
58
  eos_token_id=tokenizer.eos_token_id
@@ -79,3 +72,4 @@ interface = gr.Interface(
79
  interface.launch()
80
 
81
 
 
 
6
  # Move model to device (GPU if available)
7
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
8
 
9
+ # Load the tokenizer (use pre-trained tokenizer for GPT-2 family)
10
  tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
 
 
11
 
12
+ # Load the fine-tuned model from the local safetensors file
13
+ model_path = "./model.safetensors" # Path to your local model file
14
+ model = GPT2LMHeadModel.from_pretrained(
15
+ pretrained_model_name_or_path=None, # None because it's not from a model name
16
+ config="distilgpt2", # Specify the config for distilgpt2
17
+ local_files_only=True, # Only look for local files
18
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
19
+ )
20
+
21
+ # Load the safetensors weights
22
+ model.load_state_dict(torch.load(model_path, map_location=device))
23
 
24
+ # Move model to the device (GPU or CPU)
25
+ model.to(device)
 
 
 
 
 
 
26
 
27
  # Set up conversational memory using LangChain's ConversationBufferMemory
28
  memory = ConversationBufferMemory()
 
32
  # Retrieve conversation history
33
  conversation_history = memory.load_memory_variables({})['history']
34
 
 
 
 
 
35
  # Combine the (possibly summarized) history with the current user input
36
  full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
37
 
 
46
  num_return_sequences=1,
47
  no_repeat_ngram_size=3,
48
  repetition_penalty=1.2,
 
 
 
49
  early_stopping=True,
50
  pad_token_id=tokenizer.eos_token_id,
51
  eos_token_id=tokenizer.eos_token_id
 
72
  interface.launch()
73
 
74
 
75
+