Spaces:
Running
Running
File size: 3,604 Bytes
73f480d 4e569e3 3c83c25 54fb4b6 0cbb7c1 76812ff f3e8888 3c83c25 80fe42a 3c83c25 1d3994c 76812ff 73f480d 3c83c25 76812ff 73f480d 76812ff 3c83c25 226053b 3c83c25 2cf80ad 4e569e3 3c83c25 3445414 76812ff 3445414 3c83c25 73f480d 5a6459d 73f480d 3c83c25 73f480d 3445414 4982a76 d7fea83 73f480d 2cf80ad e584b91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import spaces
import gradio as gr
from transformers import pipeline, TextIteratorStreamer
import torch
import threading
import os
# Load model and tokenizer
# MODEL_ID comes from the environment (Space settings/secrets).
# NOTE(review): if MODEL_ID is unset this is None and pipeline() will fail
# at import time — confirm the Space always defines it.
model_name = os.getenv("MODEL_ID")
pipe = pipeline("text-generation", model=model_name, device=0)  # device=0: first CUDA GPU
tokenizer = pipe.tokenizer  # reused directly for chat templating / streaming below
model = pipe.model          # reused directly for threaded generate() below
# Fixed generation config
MAX_TOKENS = 3000   # cap on newly generated tokens per request
TEMPERATURE = 0.1   # low temperature: near-deterministic sampling
TOP_P = 0.9         # nucleus-sampling threshold
@spaces.GPU
def respond_stream(population, intervention, comparison, outcome, study_design, summary, title, abstract):
    """Stream a screening response for one title/abstract against PICOS criteria.

    Builds a prompt from the filled-in PICOS fields plus the title/abstract,
    runs ``model.generate`` on a background thread with a
    ``TextIteratorStreamer``, and yields the accumulated text so Gradio can
    update the output box incrementally.

    Args:
        population, intervention, comparison, outcome, study_design: PICOS
            criteria; at least one must be non-blank.
        summary: optional free-text PICOS summary appended to the prompt.
        title, abstract: required article fields.

    Yields:
        str: the partial (growing) model response, or an error message.
    """
    # BUG FIX: this function is a generator (it yields below), so a bare
    # `return "..."` would end iteration WITHOUT ever showing the error text
    # in the UI. Errors must be yielded, then the generator stopped.
    if not title.strip() or not abstract.strip():
        yield "❌ Error: Title and Abstract are required."
        return
    criteria_parts = []
    if population.strip():
        criteria_parts.append(f"Population of interest = {population.strip()}")
    if intervention.strip():
        criteria_parts.append(f"Intervention/exposure of interest = {intervention.strip()}")
    if comparison.strip():
        criteria_parts.append(f"Comparison of interest = {comparison.strip()}")
    if outcome.strip():
        criteria_parts.append(f"Outcome of interest = {outcome.strip()}")
    if study_design.strip():
        criteria_parts.append(f"Study design of interest = {study_design.strip()}")
    if not criteria_parts:
        yield "❌ Error: At least one of the five PICOS criteria must be filled."
        return
    # Build instruction section
    instruction = "Instruction: " + "\n".join(criteria_parts)
    # Construct full prompt
    prompt = instruction
    if summary.strip():
        prompt += f"\n\nPICOS Summary: {summary.strip()}"
    prompt += f"\n\nTitle: {title.strip()}\nAbstract: {abstract.strip()}"
    # Wrap into message for chat template
    messages = [{"role": "user", "content": prompt}]
    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Tokenize and prepare streamer; use model.device instead of a hardcoded
    # "cuda" so the inputs land wherever the pipeline actually placed the model.
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        # Pass the attention mask explicitly so generate() need not infer it
        # from pad tokens (pad == eos here, which would be ambiguous).
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks, so run it on a worker thread; the streamer feeds
    # tokens back to this generator as they are produced.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial_text = ""
    for token in streamer:
        partial_text += token
        yield partial_text
# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## PICO screener")
    # Five PICOS criteria inputs — respond_stream requires at least one filled.
    with gr.Column():
        population = gr.Textbox(label="Population of interest", lines=1)
        intervention = gr.Textbox(label="Intervention/exposure of interest", lines=1)
        comparison = gr.Textbox(label="Comparison of interest", lines=1)
        outcome = gr.Textbox(label="Outcome of interest", lines=1)
        study_design = gr.Textbox(label="Study design of interest", lines=1)
    # Article fields; title and abstract are required by respond_stream.
    with gr.Column():
        summary = gr.Textbox(label="PICOS Summary (optional)", lines=4)
        title = gr.Textbox(label="Title", lines=2, placeholder="Required")
        abstract = gr.Textbox(label="Abstract", lines=10, placeholder="Required")
    output_box = gr.Textbox(label="Model Response", lines=15, interactive=False)
    generate_btn = gr.Button("Generate")
    # respond_stream is a generator, so Gradio streams partial text into
    # output_box as tokens arrive.
    generate_btn.click(
        fn=respond_stream,
        inputs=[population, intervention, comparison, outcome, study_design, summary, title, abstract],
        outputs=[output_box]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()