Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import random | |
| from datasets import load_dataset | |
| from peft import PeftModel | |
| import os | |
# Page text handed to the Gradio interface: heading, short blurb, and the
# `article` text that Gradio renders beneath the demo.
title = "Gemma-2b SciQ"
description = "\nGemma-2b fine-tuned on SciQ\n"
article = "GitHub repository: https://github.com/P-Zande/nlp-team-4"
model_id = "google/gemma-2b"

# Hugging Face access token taken from the environment (None when unset);
# the same value is used for both the tokenizer and the model download.
hf_token = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
base_model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token)

# Load the PEFT adapter saved in the current directory ("./") on top of the
# base model, then merge its weights in so inference runs on a plain model.
model = PeftModel.from_pretrained(base_model, "./").merge_and_unload()
dataset = load_dataset("allenai/sciq")

# NOTE(review): despite the name, this is deterministically the first five
# rows of the test split, not a random sample.
random_test_samples = dataset["test"].select(range(5))

# Two Gradio examples per row: first with the answer left blank (so the model
# must propose one), then with the dataset's correct answer filled in.
examples = [
    [row['support'], given_answer]
    for row in random_test_samples
    for given_answer in ("", row['correct_answer'])
]
def predict(context="", answer=""):
    """Generate a SciQ-style multiple-choice item from a context passage.

    The fine-tuned model is prompted with the context on one line, optionally
    followed by an answer on the next line, and is expected to continue with
    the remaining fields one per line.

    Args:
        context: Supporting passage. Newlines are flattened to spaces so the
            passage occupies exactly one prompt line. ``None`` is treated as
            empty.
        answer: Optional answer. Newlines are flattened to spaces. When it is
            empty, whitespace-only, or ``None``, it is omitted from the prompt
            and the model produces the answer line itself.

    Returns:
        A 6-tuple ``(context, answer, question, distractor1, distractor2,
        distractor3)`` taken from the first six lines of the decoded output.
        If the output has fewer than six lines, returns
        ``("ERROR: " + decoded_text, None, None, None, None, None)``.
    """
    # Defensively treat a missing value as an empty string.
    context = (context or "").replace('\n', ' ')
    answer = (answer or "").replace('\n', ' ')

    prompt = context + "\n"
    if answer.strip():
        prompt += answer + "\n"

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    # outputs[0] contains the prompt tokens followed by the generated ones, so
    # the decoded text begins with the prompt's own line(s).
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)

    fields = decoded_outputs.split("\n")
    # The model may keep going after the third distractor (a trailing newline
    # or the start of another item). Keep the six expected lines instead of
    # rejecting an otherwise complete result.
    if len(fields) >= 6:
        return tuple(fields[:6])
    return ("ERROR: " + decoded_outputs, None, None, None, None, None)
# Input widgets, pre-filled with a tiny demo passage and answer.
support_gr = gr.TextArea(label="Context", value="Bananas are yellow and curved.")
answer_gr = gr.Text(label="Answer (optional)", value="yellow")

# One text box per output field, created in the same order as they are
# unpacked below.
(
    context_output_gr,
    answer_output_gr,
    question_output_gr,
    distractor1_output_gr,
    distractor2_output_gr,
    distractor3_output_gr,
) = [
    gr.Text(label=field_label)
    for field_label in (
        "Context",
        "Answer",
        "Question",
        "Distractor 1",
        "Distractor 2",
        "Distractor 3",
    )
]
# Wire predict() to the widgets: two inputs, six text outputs, plus the
# dataset-derived examples, then start the server.
demo = gr.Interface(
    fn=predict,
    inputs=[support_gr, answer_gr],
    outputs=[
        context_output_gr,
        answer_output_gr,
        question_output_gr,
        distractor1_output_gr,
        distractor2_output_gr,
        distractor3_output_gr,
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()