File size: 3,604 Bytes
73f480d
4e569e3
3c83c25
 
 
54fb4b6
0cbb7c1
76812ff
f3e8888
3c83c25
 
 
80fe42a
3c83c25
 
 
 
 
1d3994c
76812ff
 
 
73f480d
3c83c25
76812ff
 
 
 
 
 
 
 
 
 
 
 
73f480d
 
76812ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c83c25
 
 
 
 
 
 
 
226053b
3c83c25
2cf80ad
4e569e3
3c83c25
 
 
3445414
76812ff
 
3445414
3c83c25
73f480d
 
5a6459d
73f480d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c83c25
 
 
73f480d
3445414
4982a76
d7fea83
73f480d
2cf80ad
e584b91
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import spaces
import gradio as gr
from transformers import pipeline, TextIteratorStreamer
import torch
import threading
import os

# Load model and tokenizer.
# The checkpoint is selected through the MODEL_ID environment variable. Fail
# fast when it is missing: passing model=None to pipeline() would otherwise
# load the task's default checkpoint instead of the intended model.
model_name = os.getenv("MODEL_ID")
if not model_name:
    raise RuntimeError(
        "The MODEL_ID environment variable must be set to the model to load."
    )
pipe = pipeline("text-generation", model=model_name, device=0)
# Keep direct handles: generation below calls the tokenizer and model itself
# rather than going through the pipeline object.
tokenizer = pipe.tokenizer
model = pipe.model

# Fixed generation config
MAX_TOKENS = 3000  # passed to generate() as max_new_tokens
TEMPERATURE = 0.1  # sampling temperature
TOP_P = 0.9  # nucleus-sampling threshold

@spaces.GPU
def respond_stream(population, intervention, comparison, outcome, study_design, summary, title, abstract):
    """Stream the model's response for one title/abstract and PICOS criteria.

    This is a generator: every yielded value is the complete text produced
    so far (not just the newest piece), so a consumer can overwrite its
    display with each value.

    Args:
        population: Population of interest; blank to omit.
        intervention: Intervention/exposure of interest; blank to omit.
        comparison: Comparison of interest; blank to omit.
        outcome: Outcome of interest; blank to omit.
        study_design: Study design of interest; blank to omit.
        summary: Optional PICOS summary appended to the prompt when non-blank.
        title: Title of the record (required).
        abstract: Abstract of the record (required).

    Yields:
        The accumulated response text, or a single error message when the
        title or abstract is blank or when all five criteria are blank.
    """
    title = title.strip()
    abstract = abstract.strip()

    # Validation messages are yielded rather than returned: in a generator a
    # `return <value>` only ends iteration, so the value would never be
    # delivered to whoever is iterating.
    if not title or not abstract:
        yield "❌ Error: Title and Abstract are required."
        return

    labelled_criteria = (
        ("Population of interest", population),
        ("Intervention/exposure of interest", intervention),
        ("Comparison of interest", comparison),
        ("Outcome of interest", outcome),
        ("Study design of interest", study_design),
    )
    criteria_parts = [
        f"{label} = {value.strip()}"
        for label, value in labelled_criteria
        if value.strip()
    ]
    if not criteria_parts:
        yield "❌ Error: At least one of the five PICOS criteria must be filled."
        return

    # Prompt layout: instruction block, optional summary, then the record.
    prompt = "Instruction: " + "\n".join(criteria_parts)
    summary = summary.strip()
    if summary:
        prompt += f"\n\nPICOS Summary: {summary}"
    prompt += f"\n\nTitle: {title}\nAbstract: {abstract}"

    # Wrap into a single user message and render it with the chat template.
    messages = [{"role": "user", "content": prompt}]
    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # NOTE(review): the rendered template is tokenized with the tokenizer's
    # default special-token handling; if the template already emits BOS this
    # may duplicate it - confirm against how the model was trained.
    inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        streamer=streamer,
        max_new_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # The pad token is set to EOS above, so padding cannot be inferred from
    # the ids alone; hand over the tokenizer's mask explicitly.
    if "attention_mask" in inputs:
        generation_kwargs["attention_mask"] = inputs["attention_mask"]

    # generate() blocks until completion, so it runs in a worker thread while
    # this generator drains the streamer.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_text = ""
    for token in streamer:
        partial_text += token
        yield partial_text

    # The streamer is exhausted, so generation is finishing; wait for the
    # worker so it does not outlive this call.
    thread.join()

# Assemble the Gradio UI. Components render in creation order, so the
# statements below also define the on-page layout.
with gr.Blocks() as demo:
    gr.Markdown("## PICO screener")

    # One single-line box per PICOS criterion.
    with gr.Column():
        population, intervention, comparison, outcome, study_design = (
            gr.Textbox(label=f"{criterion} of interest", lines=1)
            for criterion in (
                "Population",
                "Intervention/exposure",
                "Comparison",
                "Outcome",
                "Study design",
            )
        )

    # Optional summary plus the record being screened.
    with gr.Column():
        summary = gr.Textbox(label="PICOS Summary (optional)", lines=4)
        title = gr.Textbox(label="Title", lines=2, placeholder="Required")
        abstract = gr.Textbox(label="Abstract", lines=10, placeholder="Required")

    output_box = gr.Textbox(label="Model Response", lines=15, interactive=False)
    generate_btn = gr.Button("Generate")

    # Input order must mirror respond_stream's positional parameters.
    form_inputs = [
        population,
        intervention,
        comparison,
        outcome,
        study_design,
        summary,
        title,
        abstract,
    ]
    generate_btn.click(fn=respond_stream, inputs=form_inputs, outputs=[output_box])

# Start the server only when executed as a script.
if __name__ == "__main__":
    demo.launch()