Spaces:
Running
Running
| import spaces | |
| import gradio as gr | |
| from transformers import pipeline, TextIteratorStreamer | |
| import torch | |
| import threading | |
| import os | |
# Load model and tokenizer once at startup.
# NOTE(review): if the MODEL_ID env var is unset, model_name is None and
# pipeline() falls back to its default text-generation checkpoint — confirm
# that is intended rather than a hard failure.
model_name = os.getenv("MODEL_ID")
# device=0 pins the pipeline to the first CUDA device; assumes a GPU is present.
pipe = pipeline("text-generation", model=model_name, device=0)
tokenizer = pipe.tokenizer  # reused below for chat templating and encoding
model = pipe.model  # reused below for threaded generate() streaming

# Fixed generation config (not user-adjustable via the UI)
MAX_TOKENS = 3000  # max_new_tokens cap per response
TEMPERATURE = 0.1  # near-greedy sampling for consistent screening decisions
TOP_P = 0.9  # nucleus sampling threshold
def respond_stream(population, intervention, comparison, outcome, study_design, summary, title, abstract):
    """Stream a PICOS screening response for one title/abstract record.

    Builds a prompt from the filled-in PICOS criteria plus the record's
    title and abstract, then streams the model's generated text, yielding
    the progressively growing partial response (Gradio renders each yield
    into the output textbox).

    Args:
        population, intervention, comparison, outcome, study_design:
            Free-text PICOS criteria; at least one must be non-blank.
        summary: Optional PICOS summary appended to the prompt.
        title, abstract: Required fields describing the record to screen.

    Yields:
        str: partial (cumulatively growing) model output, or a single
        "❌ Error: …" message when validation fails.
    """
    # BUG FIX: this function is a generator, so `return "msg"` never reaches
    # Gradio (it becomes the StopIteration value and the UI shows nothing).
    # Validation errors must be *yielded*, then the generator exits.
    if not title.strip() or not abstract.strip():
        yield "❌ Error: Title and Abstract are required."
        return

    # Collect whichever of the five PICOS criteria the user filled in.
    criteria_parts = []
    if population.strip():
        criteria_parts.append(f"Population of interest = {population.strip()}")
    if intervention.strip():
        criteria_parts.append(f"Intervention/exposure of interest = {intervention.strip()}")
    if comparison.strip():
        criteria_parts.append(f"Comparison of interest = {comparison.strip()}")
    if outcome.strip():
        criteria_parts.append(f"Outcome of interest = {outcome.strip()}")
    if study_design.strip():
        criteria_parts.append(f"Study design of interest = {study_design.strip()}")
    if not criteria_parts:
        yield "❌ Error: At least one of the five PICOS criteria must be filled."
        return

    # Build instruction section
    instruction = "Instruction: " + "\n".join(criteria_parts)

    # Construct full prompt
    prompt = instruction
    if summary.strip():
        prompt += f"\n\nPICOS Summary: {summary.strip()}"
    prompt += f"\n\nTitle: {title.strip()}\nAbstract: {abstract.strip()}"

    # Wrap into a single-turn chat message for the model's chat template.
    messages = [{"role": "user", "content": prompt}]
    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Tokenize on the model's own device (avoids a hard-coded "cuda" string)
    # and prepare a streamer that skips the echoed prompt and special tokens.
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        # Pass the attention mask explicitly: with pad_token_id == eos_token_id
        # generate() cannot infer it reliably and warns.
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Run generate() on a worker thread; the streamer feeds tokens back here.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial_text = ""
    for token in streamer:
        partial_text += token
        yield partial_text
    # Ensure the generation thread has fully finished before the generator exits.
    thread.join()
# Assemble the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("## PICO screener")

    # PICOS criteria — single-line fields, at least one must be filled.
    with gr.Column():
        criteria_labels = (
            "Population of interest",
            "Intervention/exposure of interest",
            "Comparison of interest",
            "Outcome of interest",
            "Study design of interest",
        )
        population, intervention, comparison, outcome, study_design = (
            gr.Textbox(label=text, lines=1) for text in criteria_labels
        )

    # The record under review plus an optional free-text summary.
    with gr.Column():
        summary = gr.Textbox(label="PICOS Summary (optional)", lines=4)
        title = gr.Textbox(label="Title", lines=2, placeholder="Required")
        abstract = gr.Textbox(label="Abstract", lines=10, placeholder="Required")

    output_box = gr.Textbox(label="Model Response", lines=15, interactive=False)
    generate_btn = gr.Button("Generate")

    # Wire the button to the streaming generator; each yield updates output_box.
    generate_btn.click(
        respond_stream,
        [population, intervention, comparison, outcome, study_design, summary, title, abstract],
        [output_box],
    )

# Start the server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()